diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 975f26e71..2ae7d9052 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -28,7 +28,7 @@ from axolotl.utils.data.shared import (
 )
 from axolotl.utils.data.utils import (
     deduplicate_and_log_datasets,
-    drop_long_seq_in_dataset,
+    handle_long_seq_in_dataset,
     retry_on_request_exceptions,
 )
 from axolotl.utils.data.wrappers import get_dataset_wrapper
@@ -339,9 +339,9 @@ def _load_raw_datasets(
 
     if not cfg.skip_prepare_dataset:
         if split == "test" and cfg.eval_sequence_len:
-            dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
+            dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
         else:
-            dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
+            dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
 
     if cfg.sample_packing:
         dataset, _ = process_datasets_for_packing(cfg, dataset, None)
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index c0efb7a42..856a609c7 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -148,7 +148,21 @@ def deduplicate_and_log_datasets(
     return dataset, other_dataset
 
 
-def drop_long_seq_in_dataset(
+def truncate_long_seq(sample, sequence_len=2048):
+    """
+    Truncate every over-length sequence in a batched sample to ``sequence_len``.
+
+    Meant for ``Dataset.map(..., batched=True)``: ``Dataset.filter`` only
+    consumes the returned boolean mask and discards in-place edits to the
+    batch, so truncation has to happen in a map step, not a filter step.
+    """
+    for key in ("input_ids", "attention_mask", "labels", "position_ids"):
+        if key in sample:
+            sample[key] = [seq[:sequence_len] for seq in sample[key]]
+    return sample
+
+
+def handle_long_seq_in_dataset(
     dataset: Dataset, sequence_len: int, cfg: DictDefault
 ) -> Dataset:
-    """Remove sequences longer than configured maximum from dataset.
+    """Drop or truncate sequences longer than the configured maximum.
@@ -192,14 +206,31 @@ def drop_long_seq_in_dataset(
     if filter_map_kwargs:
         drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
 
+    excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
+    if excess_length_strategy == "truncate":
+        # Truncation must run as a map step: Dataset.filter only uses the
+        # returned boolean mask and silently discards batch mutations.
+        truncate_kwargs = {}
+        if filter_map_kwargs:
+            truncate_kwargs["desc"] = f"Truncating Long Sequences (>{sequence_len})"
+        dataset = dataset.map(
+            functools.partial(truncate_long_seq, sequence_len=sequence_len),
+            batched=True,
+            **filter_map_kwargs,
+            **truncate_kwargs,
+        )
+
+    # The drop filter always runs; after truncation it only removes rows that
+    # are still out of bounds (drop_long also enforces the minimum length).
     dataset = dataset.filter(
         drop_long,
         batched=True,
         **filter_map_kwargs,
         **drop_long_kwargs,
     )
     if prior_len:
         dropped = prior_len - len(dataset)
         if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from dataset")
+            noun = "short" if excess_length_strategy == "truncate" else "long"
+            LOG.warning(f"Dropped {dropped} {noun} samples from dataset")
     return dataset
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 21e99c048..a607b3dca 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -414,6 +414,12 @@ class AxolotlInputConfig(
             "description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
         },
     )
+    excess_length_strategy: Literal["drop", "truncate"] | None = Field(
+        default=None,
+        json_schema_extra={
+            "description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices input_ids/attention_mask/labels/position_ids to sequence_len. Defaults to 'drop' for backward compatibility."
+        },
+    )
     eval_sequence_len: int | None = Field(
         default=None,
         json_schema_extra={
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index 7cb645db7..47894a35b 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies.completion import load
 from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
-from axolotl.utils.data.utils import drop_long_seq_in_dataset
+from axolotl.utils.data.utils import handle_long_seq_in_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
         )
 
         train_dataset = concatenate_datasets([dataset_wrapper])
-        train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
+        train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
         lengths = get_dataset_lengths(train_dataset)
         batch_sampler = MultipackBatchSampler(