feature: raise on long sequence drop (#3321)
* feature: raise on long sequence drop

  It is sometimes undesirable for sequences to be silently dropped from the dataset, especially when the dataset has been carefully crafted and pre-fitted to the training context. In that case, a dropped sequence suggests that an error occurred somewhere in the process. This feature adds a third value for excess_length_strategy called 'raise', which raises a ValueError if a sequence is encountered that is too long and would normally have been dropped/truncated.

* tests: add excess_length_strategy tests
* doc: updated return value description for drop_long_seq_in_dataset
* add @enable_hf_offline
* fixed cfg modified after validate_config called
* hf offline fix
* fix tqdm desc when raise is used
* test: added test for non-batched case
* accidental code change revert
* test: use pytest.raises
* test: simplified drop_seq_len tests
* test: moved excess_length_strat test to test_data.py

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
This commit is contained in:
@@ -188,7 +188,10 @@ def handle_long_seq_in_dataset(
|
|||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Filtered dataset with long sequences removed.
|
Filtered dataset with long sequences handled according to the excess_length_strategy value:
|
||||||
|
'drop' (default) excludes any sequence longer than sequence_len
|
||||||
|
'truncate' truncates them down to sequence_len
|
||||||
|
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
hasattr(dataset, "column_names")
|
hasattr(dataset, "column_names")
|
||||||
@@ -206,10 +209,13 @@ def handle_long_seq_in_dataset(
|
|||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||||
|
|
||||||
drop_long = functools.partial(
|
drop_long = functools.partial(
|
||||||
drop_long_seq,
|
drop_long_seq,
|
||||||
sequence_len=sequence_len,
|
sequence_len=sequence_len,
|
||||||
min_sequence_len=cfg.min_sample_len,
|
min_sequence_len=cfg.min_sample_len,
|
||||||
|
raise_on_drop=excess_length_strategy == "raise",
|
||||||
)
|
)
|
||||||
|
|
||||||
with contextlib.suppress(AttributeError):
|
with contextlib.suppress(AttributeError):
|
||||||
@@ -228,9 +234,13 @@ def handle_long_seq_in_dataset(
|
|||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
action = (
|
||||||
|
"Checking Sequence Lengths"
|
||||||
|
if excess_length_strategy == "raise"
|
||||||
|
else "Dropping Long Sequences"
|
||||||
|
)
|
||||||
|
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
|
||||||
|
|
||||||
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
|
||||||
if excess_length_strategy == "truncate":
|
if excess_length_strategy == "truncate":
|
||||||
process_fn = functools.partial(
|
process_fn = functools.partial(
|
||||||
truncate_long_seq,
|
truncate_long_seq,
|
||||||
|
|||||||
@@ -452,10 +452,10 @@ class AxolotlInputConfig(
|
|||||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
|
excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
|
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
eval_sequence_len: int | None = Field(
|
eval_sequence_len: int | None = Field(
|
||||||
|
|||||||
@@ -205,12 +205,15 @@ def add_length(sample):
|
|||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
|
||||||
"""
|
"""
|
||||||
Drop samples whose sequence length is either too long (> sequence_len)
|
Drop samples whose sequence length is either too long (> sequence_len)
|
||||||
or too short (< min_sequence_len).
|
or too short (< min_sequence_len).
|
||||||
|
|
||||||
Works for both single-example (list[int]) or batched (list[list[int]]).
|
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||||
|
|
||||||
|
If raise_on_drop is set, the code raises a ValueError if a sample is
|
||||||
|
encountered that is too long and would have been dropped.
|
||||||
"""
|
"""
|
||||||
min_sequence_len = min_sequence_len or 2
|
min_sequence_len = min_sequence_len or 2
|
||||||
|
|
||||||
@@ -225,12 +228,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
if isinstance(input_ids[0], int):
|
if isinstance(input_ids[0], int):
|
||||||
# Single example (input_ids is a list of int)
|
# Single example (input_ids is a list of int)
|
||||||
length = len(input_ids)
|
length = len(input_ids)
|
||||||
|
if raise_on_drop and length > sequence_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
|
||||||
|
)
|
||||||
return min_sequence_len <= length <= sequence_len
|
return min_sequence_len <= length <= sequence_len
|
||||||
|
|
||||||
# Batched (input_ids is a list of lists)
|
# Batched (input_ids is a list of lists)
|
||||||
results = []
|
results = []
|
||||||
for seq in input_ids:
|
for seq in input_ids:
|
||||||
length = len(seq)
|
length = len(seq)
|
||||||
|
if raise_on_drop and length > sequence_len:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}."
|
||||||
|
)
|
||||||
results.append(min_sequence_len <= length <= sequence_len)
|
results.append(min_sequence_len <= length <= sequence_len)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import unittest
|
|||||||
from transformers import LlamaTokenizer
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import encode_streaming, md5
|
from axolotl.utils.data import encode_streaming, md5
|
||||||
|
from axolotl.utils.trainer import drop_long_seq
|
||||||
|
|
||||||
from tests.hf_offline_utils import enable_hf_offline
|
from tests.hf_offline_utils import enable_hf_offline
|
||||||
|
|
||||||
@@ -63,6 +64,42 @@ class TestEncodePretraining(unittest.TestCase):
|
|||||||
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_excess_length_strategy(self):
|
||||||
|
"""Test that excess_length_strategy results in a value error when set to 'raise'."""
|
||||||
|
|
||||||
|
# -- single sequence --
|
||||||
|
# This should work
|
||||||
|
data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]}
|
||||||
|
drop_long_seq(data, 32, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should return True, since data fits
|
||||||
|
dropped = drop_long_seq(data, 32)
|
||||||
|
self.assertTrue(dropped)
|
||||||
|
|
||||||
|
# This should raise
|
||||||
|
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should return False, since data doesn't fit
|
||||||
|
dropped = drop_long_seq(data, 15)
|
||||||
|
self.assertFalse(dropped)
|
||||||
|
|
||||||
|
# -- batch sequence --
|
||||||
|
# This should work
|
||||||
|
data = {
|
||||||
|
"input_ids": [
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
|
||||||
|
]
|
||||||
|
}
|
||||||
|
drop_long_seq(data, 32, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should raise
|
||||||
|
self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True)
|
||||||
|
|
||||||
|
# This should keep the first but drop the second entry
|
||||||
|
dropped = drop_long_seq(data, 15)
|
||||||
|
self.assertEqual(dropped, [True, False])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ from transformers import PreTrainedTokenizer
|
|||||||
|
|
||||||
from axolotl.loaders.tokenizer import load_tokenizer
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
from axolotl.utils.data.rl import prepare_preference_datasets
|
from axolotl.utils.data.rl import prepare_preference_datasets
|
||||||
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
|
from axolotl.utils.data.sft import (
|
||||||
|
_load_tokenized_prepared_datasets,
|
||||||
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
from tests.constants import (
|
from tests.constants import (
|
||||||
|
|||||||
Reference in New Issue
Block a user