Compare commits


1 commit

coderabbitai[bot] · 0fccbadb79 · 2025-12-18 05:49:01 +00:00
📝 Add docstrings to 202512-raise_on_drop

Docstrings generation was requested by @kallewoof.

* https://github.com/axolotl-ai-cloud/axolotl/pull/3321#issuecomment-3668489902

The following files were modified:

* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
2 changed files with 48 additions and 16 deletions

`src/axolotl/utils/data/utils.py`

@@ -180,15 +180,20 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
 def handle_long_seq_in_dataset(
     dataset: Dataset, sequence_len: int, cfg: DictDefault
 ) -> Dataset:
-    """Remove sequences longer than configured maximum from dataset.
-
-    Args:
-        dataset: Dataset to filter.
-        sequence_len: Maximum length for sequences to keep
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        Filtered dataset with long sequences removed.
+    """
+    Remove or truncate sequences that exceed the configured maximum length from a dataset.
+
+    Parameters:
+        dataset (Dataset): Dataset to process; if it lacks an "input_ids" column or is streaming, it is returned unchanged.
+        sequence_len (int): Maximum allowed sequence length; sequences longer than this are either removed or truncated.
+        cfg (DictDefault): Configuration object with keys:
+            - excess_length_strategy: "drop", "truncate", or "raise" — determines how to handle overlong sequences.
+            - min_sample_len: minimum allowed sequence length (used when truncating or dropping).
+            - dataset_num_proc: number of processes to use for non-streaming datasets.
+            - is_preprocess: when true, bypasses cached preprocessing during filtering.
+
+    Returns:
+        Dataset: The input dataset with sequences longer than `sequence_len` removed or truncated according to `cfg`.
     """
     if (
         hasattr(dataset, "column_names")
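
The cfg keys called out in the new docstring correspond to axolotl config options. As a quick orientation, here is a minimal sketch of how they might be supplied; the import path and the chosen values are assumptions for illustration, not taken from this diff:

```python
from axolotl.utils.dict import DictDefault  # assumed import path

# Hypothetical config: error out instead of silently dropping overlong samples.
cfg = DictDefault(
    {
        "excess_length_strategy": "raise",  # "drop" | "truncate" | "raise"
        "min_sample_len": 2,
        "dataset_num_proc": 4,
        "is_preprocess": False,
    }
)
dataset = handle_long_seq_in_dataset(dataset, sequence_len=2048, cfg=cfg)
```
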
@@ -206,10 +211,13 @@ def handle_long_seq_in_dataset(
         )
         return dataset
 
+    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
+
     drop_long = functools.partial(
         drop_long_seq,
         sequence_len=sequence_len,
         min_sequence_len=cfg.min_sample_len,
+        raise_on_drop=excess_length_strategy == "raise",
     )
 
     with contextlib.suppress(AttributeError):

@@ -230,7 +238,6 @@ def handle_long_seq_in_dataset(
     if filter_map_kwargs:
         drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
 
-    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
     if excess_length_strategy == "truncate":
         process_fn = functools.partial(
             truncate_long_seq,
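
Read together, the two hunks above resolve the strategy string once, bind it into the filter predicate, and then branch on it. A condensed sketch of the resulting control flow, paraphrased from the diff context rather than copied from the source (the map/filter calls are assumptions):

```python
import functools

excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
drop_long = functools.partial(
    drop_long_seq,
    sequence_len=sequence_len,
    min_sequence_len=cfg.min_sample_len,
    raise_on_drop=excess_length_strategy == "raise",  # "raise" reuses the drop path
)

if excess_length_strategy == "truncate":
    # Overlong samples are shortened rather than removed.
    process_fn = functools.partial(
        truncate_long_seq,
        sequence_len=sequence_len,
        min_sequence_len=cfg.min_sample_len,
    )
    dataset = dataset.map(process_fn, **drop_long_kwargs)
else:
    # "drop" filters overlong samples out; "raise" makes the same
    # predicate raise a ValueError instead of returning False.
    dataset = dataset.filter(drop_long, **drop_long_kwargs)
```
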

`src/axolotl/utils/trainer.py`

@@ -201,16 +201,33 @@ def add_pose_position_ids(
 def add_length(sample):
+    """
+    Set the "length" field on a sample to the number of input tokens.
+
+    Parameters:
+        sample (Mapping-like): A sample containing an "input_ids" sequence.
+
+    Returns:
+        sample (dict-like): The same sample with "length" set to len(sample["input_ids"]).
+    """
     sample["length"] = len(sample["input_ids"])
     return sample
 
 
-def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
     """
-    Drop samples whose sequence length is either too long (> sequence_len)
-    or too short (< min_sequence_len).
+    Return whether a sample (single or batched) should be kept based on sequence length constraints.
 
-    Works for both single-example (list[int]) or batched (list[list[int]]).
+    Determines if each sequence's length falls within [min_sequence_len, sequence_len]. Supports a single example (list[int]) or a batch (list[list[int]]). If the sample's "input_ids" is empty, the sample is treated as dropped. When raise_on_drop is True, encountering any sequence longer than sequence_len raises a ValueError.
+
+    Parameters:
+        sample (dict): A mapping containing "input_ids" with either a single sequence or a batch of sequences.
+        sequence_len (int): Maximum allowed sequence length (inclusive).
+        min_sequence_len (int): Minimum allowed sequence length (inclusive).
+        raise_on_drop (bool): If True, raise ValueError when a sequence exceeds sequence_len.
+
+    Returns:
+        bool or list[bool]: For a single example, returns True if its length is within the bounds, False otherwise. For a batch, returns a list of booleans indicating which sequences should be kept.
     """
 
     min_sequence_len = min_sequence_len or 2
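
add_length itself is unchanged; the commit only documents it. For reference, a self-contained usage sketch with Hugging Face datasets (the toy data is invented for illustration):

```python
from datasets import Dataset

def add_length(sample):
    # Same body as above: record the token count of each sample.
    sample["length"] = len(sample["input_ids"])
    return sample

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})
ds = ds.map(add_length)
print(ds["length"])  # [3, 2]
```
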
@@ -225,12 +242,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
     if isinstance(input_ids[0], int):
         # Single example (input_ids is a list of int)
         length = len(input_ids)
+        if raise_on_drop and length > sequence_len:
+            raise ValueError(
+                f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
+            )
         return min_sequence_len <= length <= sequence_len
 
     # Batched (input_ids is a list of lists)
     results = []
     for seq in input_ids:
         length = len(seq)
+        if raise_on_drop and length > sequence_len:
+            raise ValueError(
+                f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
+            )
         results.append(min_sequence_len <= length <= sequence_len)
     return results
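
The new raise_on_drop flag only changes behavior for overlong sequences; in-range and too-short sequences keep the existing keep/drop semantics. A short demonstration with made-up token counts:

```python
sample = {"input_ids": [1, 2, 3, 4]}

drop_long_seq(sample, sequence_len=8)  # True: 4 tokens fits within [2, 8]
drop_long_seq(sample, sequence_len=3)  # False: too long, silently dropped

batch = {"input_ids": [[1, 2], [1, 2, 3, 4, 5]]}
drop_long_seq(batch, sequence_len=4)   # [True, False]

# With the new flag, the same overlong input becomes a hard error:
drop_long_seq(sample, sequence_len=3, raise_on_drop=True)  # raises ValueError
```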