feature: raise on long sequence drop (#3321)

* feature: raise on long sequence drop

Silently dropping sequences from the dataset is sometimes undesirable, especially when the dataset has been carefully crafted and pre-fitted for the training context; a dropped sequence would then suggest that an error occurred somewhere in the process. This feature adds a third value for excess_length_strategy, 'raise', which raises a ValueError if a sequence is encountered that is too long and would normally have been dropped/truncated.

* tests: add excess_length_strategy tests

* doc: updated return value description for drop_long_seq_in_dataset

* add @enable_hf_offline

* fixed cfg modified after validate_config called

* hf offline fix

* fix tqdm desc when raise is used

* test: added test for non-batched case

* accidental code change revert

* test: use pytest.raises

* test: simplified drop_seq_len tests

* test: moved excess_length_strat test to test_data.py

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
kallewoof
2025-12-23 03:59:49 +09:00
committed by GitHub
parent efeb5a4e41
commit 92ee4256f7
5 changed files with 67 additions and 7 deletions

View File

@@ -188,7 +188,10 @@ def handle_long_seq_in_dataset(
cfg: Dictionary mapping `axolotl` config keys to values.
Returns:
Filtered dataset with long sequences removed.
Filtered dataset with long sequences handled according to the excess_length_strategy value:
'drop' (default) excludes any sequence longer than sequence_len
'truncate' truncates them down to sequence_len
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
"""
if (
hasattr(dataset, "column_names")
@@ -206,10 +209,13 @@ def handle_long_seq_in_dataset(
)
return dataset
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
drop_long = functools.partial(
drop_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
raise_on_drop=excess_length_strategy == "raise",
)
with contextlib.suppress(AttributeError):
@@ -228,9 +234,13 @@ def handle_long_seq_in_dataset(
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
action = (
"Checking Sequence Lengths"
if excess_length_strategy == "raise"
else "Dropping Long Sequences"
)
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,

View File

@@ -452,10 +452,10 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field(
default=None,
json_schema_extra={
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility."
},
)
eval_sequence_len: int | None = Field(

View File

@@ -205,12 +205,15 @@ def add_length(sample):
return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
If raise_on_drop is set, the code raises a ValueError if a sample is
encountered that is too long and would have been dropped.
"""
min_sequence_len = min_sequence_len or 2
@@ -225,12 +228,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int)
length = len(input_ids)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
return min_sequence_len <= length <= sequence_len
# Batched (input_ids is a list of lists)
results = []
for seq in input_ids:
length = len(seq)
if raise_on_drop and length > sequence_len:
raise ValueError(
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
)
results.append(min_sequence_len <= length <= sequence_len)
return results