diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index 2d0ca9d0e..319e27f6f 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -188,7 +188,10 @@ def handle_long_seq_in_dataset(
         cfg: Dictionary mapping `axolotl` config keys to values.
 
     Returns:
-        Filtered dataset with long sequences removed.
+        Filtered dataset with long sequences handled according to excess_length_strategy:
+            'drop' (default) removes any sequence longer than sequence_len,
+            'truncate' truncates such sequences down to sequence_len, and
+            'raise' raises a ValueError if any sequence exceeds sequence_len.
     """
     if (
         hasattr(dataset, "column_names")
@@ -206,10 +209,13 @@ def handle_long_seq_in_dataset(
         )
         return dataset
 
+    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
+
     drop_long = functools.partial(
         drop_long_seq,
         sequence_len=sequence_len,
         min_sequence_len=cfg.min_sample_len,
+        raise_on_drop=excess_length_strategy == "raise",
     )
 
     with contextlib.suppress(AttributeError):
@@ -228,9 +234,13 @@ def handle_long_seq_in_dataset(
 
     drop_long_kwargs = {}
    if filter_map_kwargs:
-        drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
+        action = (
+            "Checking Sequence Lengths"
+            if excess_length_strategy == "raise"
+            else "Dropping Long Sequences"
+        )
+        drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
 
-    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
     if excess_length_strategy == "truncate":
         process_fn = functools.partial(
             truncate_long_seq,
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index e0c9acd4d..f2f4a311a 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -452,10 +452,10 @@ class AxolotlInputConfig(
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
         },
     )
-    excess_length_strategy: Literal["drop", "truncate"] | None = Field(
+    excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field(
         default=None,
         json_schema_extra={
-            "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
+            "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility."
         },
     )
     eval_sequence_len: int | None = Field(
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index d97577d86..3628fd85f 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -205,12 +205,15 @@ def add_length(sample):
     return sample
 
 
-def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
+def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
     """
     Drop samples whose sequence length is either too long (> sequence_len) or
     too short (< min_sequence_len).
 
     Works for both single-example (list[int]) or batched (list[list[int]]).
+
+    If raise_on_drop is True, a ValueError is raised for any sample longer
+    than sequence_len instead of silently dropping it.
""" min_sequence_len = min_sequence_len or 2 @@ -225,12 +228,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): if isinstance(input_ids[0], int): # Single example (input_ids is a list of int) length = len(input_ids) + if raise_on_drop and length > sequence_len: + raise ValueError( + f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}." + ) return min_sequence_len <= length <= sequence_len # Batched (input_ids is a list of lists) results = [] for seq in input_ids: length = len(seq) + if raise_on_drop and length > sequence_len: + raise ValueError( + f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}." + ) results.append(min_sequence_len <= length <= sequence_len) return results diff --git a/tests/test_data.py b/tests/test_data.py index 99ed06336..ad76bbf6e 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -7,6 +7,7 @@ import unittest from transformers import LlamaTokenizer from axolotl.utils.data import encode_streaming, md5 +from axolotl.utils.trainer import drop_long_seq from tests.hf_offline_utils import enable_hf_offline @@ -63,6 +64,42 @@ class TestEncodePretraining(unittest.TestCase): md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3" ) + def test_excess_length_strategy(self): + """Test that excess_length_strategy results in a value error when set to 'raise'.""" + + # -- single sequence -- + # This should work + data = {"input_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]} + drop_long_seq(data, 32, raise_on_drop=True) + + # This should return True, since data fits + dropped = drop_long_seq(data, 32) + self.assertTrue(dropped) + + # This should raise + self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True) + + # This should return False, since data doesn't fit + dropped = drop_long_seq(data, 15) + self.assertFalse(dropped) + + # -- batch sequence -- + # This should work + data = { + "input_ids": [ + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + ] + } + drop_long_seq(data, 32, raise_on_drop=True) + + # This should raise + self.assertRaises(ValueError, drop_long_seq, data, 15, raise_on_drop=True) + + # This should keep the first but drop the second entry + dropped = drop_long_seq(data, 15) + self.assertEqual(dropped, [True, False]) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bd1c8f2c2..3b24ad580 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -13,7 +13,9 @@ from transformers import PreTrainedTokenizer from axolotl.loaders.tokenizer import load_tokenizer from axolotl.utils.data.rl import prepare_preference_datasets -from axolotl.utils.data.sft import _load_tokenized_prepared_datasets +from axolotl.utils.data.sft import ( + _load_tokenized_prepared_datasets, +) from axolotl.utils.dict import DictDefault from tests.constants import (