diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 2d0ca9d0e..68eb57935 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -180,15 +180,20 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2): def handle_long_seq_in_dataset( dataset: Dataset, sequence_len: int, cfg: DictDefault ) -> Dataset: - """Remove sequences longer than configured maximum from dataset. - - Args: - dataset: Dataset to filter. - sequence_len: Maximum length for sequences to keep - cfg: Dictionary mapping `axolotl` config keys to values. - + """ + Remove or truncate sequences that exceed the configured maximum length from a dataset. + + Parameters: + dataset (Dataset): Dataset to process; if it lacks an "input_ids" column or is streaming, it is returned unchanged. + sequence_len (int): Maximum allowed sequence length; sequences longer than this are either removed or truncated. + cfg (DictDefault): Configuration object with keys: + - excess_length_strategy: "drop", "truncate", or "raise" — determines how to handle overlong sequences. + - min_sample_len: minimum allowed sequence length (used when truncating or dropping). + - dataset_num_proc: number of processes to use for non-streaming datasets. + - is_preprocess: when true, bypasses cached preprocessing during filtering. + Returns: - Filtered dataset with long sequences removed. + Dataset: The input dataset with sequences longer than `sequence_len` removed or truncated according to `cfg`. """ if ( hasattr(dataset, "column_names") @@ -206,10 +211,13 @@ def handle_long_seq_in_dataset( ) return dataset + excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() + drop_long = functools.partial( drop_long_seq, sequence_len=sequence_len, min_sequence_len=cfg.min_sample_len, + raise_on_drop=excess_length_strategy == "raise", ) with contextlib.suppress(AttributeError): @@ -230,7 +238,6 @@ def handle_long_seq_in_dataset( if filter_map_kwargs: drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})" - excess_length_strategy = (cfg.excess_length_strategy or "drop").lower() if excess_length_strategy == "truncate": process_fn = functools.partial( truncate_long_seq, @@ -259,4 +266,4 @@ def handle_long_seq_in_dataset( ) LOG.warning(f"{action.title()} {dropped} samples from dataset") - return dataset + return dataset \ No newline at end of file diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index d97577d86..6be728e66 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -201,16 +201,33 @@ def add_pose_position_ids( def add_length(sample): + """ + Set the "length" field on a sample to the number of input tokens. + + Parameters: + sample (Mapping-like): A sample containing an "input_ids" sequence. + + Returns: + sample (dict-like): The same sample with "length" set to len(sample["input_ids"]). + """ sample["length"] = len(sample["input_ids"]) return sample -def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): +def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False): """ - Drop samples whose sequence length is either too long (> sequence_len) - or too short (< min_sequence_len). - - Works for both single-example (list[int]) or batched (list[list[int]]). + Return whether a sample (single or batched) should be kept based on sequence length constraints. + + Determines if each sequence's length falls within [min_sequence_len, sequence_len]. Supports a single example (list[int]) or a batch (list[list[int]]). If the sample's "input_ids" is empty, the sample is treated as dropped. When raise_on_drop is True, encountering any sequence longer than sequence_len raises a ValueError. + + Parameters: + sample (dict): A mapping containing "input_ids" with either a single sequence or a batch of sequences. + sequence_len (int): Maximum allowed sequence length (inclusive). + min_sequence_len (int): Minimum allowed sequence length (inclusive). + raise_on_drop (bool): If True, raise ValueError when a sequence exceeds sequence_len. + + Returns: + bool or list[bool]: For a single example, returns True if its length is within the bounds, False otherwise. For a batch, returns a list of booleans indicating which sequences should be kept. """ min_sequence_len = min_sequence_len or 2 @@ -225,12 +242,20 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2): if isinstance(input_ids[0], int): # Single example (input_ids is a list of int) length = len(input_ids) + if raise_on_drop and length > sequence_len: + raise ValueError( + f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}." + ) return min_sequence_len <= length <= sequence_len # Batched (input_ids is a list of lists) results = [] for seq in input_ids: length = len(seq) + if raise_on_drop and length > sequence_len: + raise ValueError( + f"Sequence encountered with {length} tokens, which exceeds the maximum {sequence_len}." + ) results.append(min_sequence_len <= length <= sequence_len) return results @@ -715,4 +740,4 @@ def setup_trainer( trainer_builder.train_dataset = train_dataset trainer_builder.eval_dataset = eval_dataset - return trainer_builder.build(total_num_steps) + return trainer_builder.build(total_num_steps) \ No newline at end of file