From 9f68918f130417c8f721da22fa49e1769d686fd0 Mon Sep 17 00:00:00 2001
From: mhenrhcsen
Date: Mon, 12 May 2025 14:08:43 +0200
Subject: [PATCH] Implement configurable handling of excess tokens in datasets

- Added `excess_token_handling` option to the configuration, allowing users
  to choose between "drop" and "truncate" for handling tokens exceeding the
  maximum sequence length.
- Introduced `truncate_or_drop_long_seq` function to manage both single and
  batched samples based on the selected handling method.
- Updated relevant dataset processing functions to use the new handling
  option, ensuring backward compatibility with the existing "drop" behavior.
- Enhanced logging to reflect truncation actions in dataset processing.

This change improves flexibility in managing sequence lengths during training
and evaluation.
---
 docs/config.qmd                       |   2 +
 src/axolotl/utils/data/pretraining.py |   2 +
 src/axolotl/utils/data/rl.py          |  95 ++++++++++++++++----
 src/axolotl/utils/data/utils.py       |  61 +++++++----
 src/axolotl/utils/schemas/config.py   |   4 +
 src/axolotl/utils/trainer.py          | 124 ++++++++++++++++++++--
 6 files changed, 247 insertions(+), 41 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index 1cff9e6f4..227797fec 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -332,6 +332,8 @@ dataset_shard_idx:
 # The maximum length of an input to train with, this should typically be less than 2048
 # as most models have a token/context limit of 2048
 sequence_len: 2048
+# How to handle tokens exceeding max sequence length - "drop" (default, removes the sample) or "truncate" (cuts off excess tokens)
+excess_token_handling: drop
 # Pad inputs so each step uses constant sized buffers
 # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
 pad_to_sequence_len:
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
index f20ced221..8fc01142f 100644
--- a/src/axolotl/utils/data/pretraining.py
+++ b/src/axolotl/utils/data/pretraining.py
@@ -259,6 +259,8 @@ def encode_packed_pretraining(
         # FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
         # workaround by using the position id logic for now in trainer
         drop_attention_mask=multipack_attn,
+        # pass through the handling mode from the config via the ds_wrapper function
+        handling=getattr(ds_wrapper, "cfg", {}).get("excess_token_handling", "drop"),
     )

     sampler = MultipackBatchSampler(
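For illustration, a minimal sketch (not part of the patch) of how the new key
is read, mirroring the `cfg.get` calls in rl.py below; `DictDefault` is
axolotl's existing dict wrapper, and the values shown are illustrative:

    from axolotl.utils.dict import DictDefault

    cfg = DictDefault({"sequence_len": 2048, "excess_token_handling": "truncate"})
    handling = cfg.get("excess_token_handling", "drop")
    assert handling == "truncate"  # falls back to "drop" when the key is absent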
diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index 135de61a3..052dff43a 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -78,7 +78,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):


 def drop_long_rl_seq(
-    sample, rl, tokenizer, sequence_len  # pylint: disable=invalid-name
+    sample, rl, tokenizer, sequence_len, handling="drop"  # pylint: disable=invalid-name
 ):
     if rl in ("dpo", "ipo", "orpo", "simpo"):
         if not (
@@ -96,9 +96,39 @@ def drop_long_rl_seq(
         len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
         len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])

-        return (len_prompt + len_chosen) <= sequence_len and (
-            len_prompt + len_rejected
-        ) <= sequence_len
+        if handling == "drop":
+            return (len_prompt + len_chosen) <= sequence_len and (
+                len_prompt + len_rejected
+            ) <= sequence_len
+
+        # handling == "truncate": if both sequences fit, return the sample unchanged
+        if (len_prompt + len_chosen) <= sequence_len and (
+            len_prompt + len_rejected
+        ) <= sequence_len:
+            return sample
+
+        # Truncate the chosen and rejected responses to fit within
+        # sequence_len while preserving the prompt
+        max_response_len = sequence_len - len_prompt
+
+        if max_response_len <= 0:
+            # The prompt alone exceeds sequence_len, so truncation cannot
+            # help; return the sample unchanged
+            return sample
+
+        if len_chosen > max_response_len:
+            # Tokenize, truncate, and decode
+            chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
+                "input_ids"
+            ][:max_response_len]
+            sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
+
+        if len_rejected > max_response_len:
+            # Tokenize, truncate, and decode
+            rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
+                "input_ids"
+            ][:max_response_len]
+            sample["rejected"] = tokenizer.decode(rejected_tokens, skip_special_tokens=True)
+
+        return sample

     if rl == "kto":
         if not (sample.get("prompt") and sample.get("completion")):
@@ -112,10 +142,30 @@ def drop_long_rl_seq(
         len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
         len_completion = len(
             tokenizer(completion, add_special_tokens=False)["input_ids"]
         )

-        return (len_prompt + len_completion) <= sequence_len
+        if handling == "drop":
+            return (len_prompt + len_completion) <= sequence_len
+
+        # handling == "truncate": if the sequence fits, return the sample unchanged
+        if (len_prompt + len_completion) <= sequence_len:
+            return sample
+
+        # Maximum completion length that can fit alongside the prompt
+        max_completion_len = sequence_len - len_prompt
+
+        if max_completion_len <= 0:
+            # The prompt alone exceeds sequence_len, so truncation cannot
+            # help; return the sample unchanged
+            return sample
+
+        if len_completion > max_completion_len:
+            # Tokenize, truncate, and decode
+            completion_tokens = tokenizer(completion, add_special_tokens=False)[
+                "input_ids"
+            ][:max_completion_len]
+            sample["completion"] = tokenizer.decode(
+                completion_tokens, skip_special_tokens=True
+            )
+
+        return sample

     if rl == "grpo":
-        return True
+        return True if handling == "drop" else sample

     raise ValueError("Unknown RL type")
@@ -169,20 +219,33 @@ def load_prepare_preference_datasets(cfg):
                 rl=_cfg.rl,
                 tokenizer=tokenizer,
                 sequence_len=cfg.sequence_len,
+                handling=cfg.get("excess_token_handling", "drop"),
             )
             prior_len = len(split_datasets[i])
-            split_datasets[i] = split_datasets[i].filter(
-                drop_long,
-                num_proc=cfg.dataset_processes,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Dropping Long Sequences",
-            )
-            dropped = prior_len - len(split_datasets[i])
-            if dropped:
-                LOG.warning(
-                    f"Dropped {dropped} long samples from dataset index {i}"
-                )
+
+            # Use filter for drop mode and map for truncate mode
+            handling = cfg.get("excess_token_handling", "drop")
+            if handling == "drop":
+                split_datasets[i] = split_datasets[i].filter(
+                    drop_long,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Dropping Long Sequences",
+                )
+                dropped = prior_len - len(split_datasets[i])
+                if dropped:
+                    LOG.warning(
+                        f"Dropped {dropped} long samples from dataset index {i}"
+                    )
+            else:
+                split_datasets[i] = split_datasets[i].map(
+                    drop_long,
+                    num_proc=cfg.dataset_processes,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Truncating Long Sequences",
+                )
+                LOG.info(
+                    f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens"
+                )

     combined_datasets = concatenate_datasets(split_datasets)
     combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
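As a standalone sketch of the truncate path above (not part of the patch; the
model name and sample values are illustrative, any Hugging Face tokenizer
works the same way):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    sample = {"prompt": "Question: why?", "chosen": "word " * 5000}
    sequence_len = 2048

    len_prompt = len(tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"])
    max_response_len = sequence_len - len_prompt

    chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
    if len(chosen_ids) > max_response_len:
        # token-level truncation, then decode back to text, as the patch does
        sample["chosen"] = tokenizer.decode(
            chosen_ids[:max_response_len], skip_special_tokens=True
        )

Note that decoding truncated token ids and later re-tokenizing the text is not
guaranteed to round-trip to exactly max_response_len tokens, so a truncated
sample may land a few tokens over or under the limit.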
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index a8e19582e..8a07af51e 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -13,7 +13,7 @@ from datasets import Dataset, IterableDataset

 from axolotl.utils.dict import DictDefault
 from axolotl.utils.samplers.utils import get_dataset_lengths
-from axolotl.utils.trainer import drop_long_seq
+from axolotl.utils.trainer import drop_long_seq, truncate_or_drop_long_seq

 LOG = logging.getLogger(__name__)

@@ -165,11 +165,24 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
         )
         return dataset

-    drop_long = functools.partial(
-        drop_long_seq,
-        sequence_len=cfg.sequence_len,
-        min_sequence_len=cfg.min_sample_len,
-    )
+    # Get the handling method from the config; default to "drop" for backward compatibility
+    handling = cfg.get("excess_token_handling", "drop")
+
+    if handling == "drop":
+        # Use the existing drop_long_seq function for backward compatibility
+        seq_handler = functools.partial(
+            drop_long_seq,
+            sequence_len=cfg.sequence_len,
+            min_sequence_len=cfg.min_sample_len,
+        )
+    else:  # handling == "truncate"
+        # Use the new function in truncate mode
+        seq_handler = functools.partial(
+            truncate_or_drop_long_seq,
+            sequence_len=cfg.sequence_len,
+            min_sequence_len=cfg.min_sample_len,
+            handling=handling,
+        )

     try:
         ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
@@ -193,17 +206,31 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):

     drop_long_kwargs = {}
     if filter_map_kwargs:
-        drop_long_kwargs["desc"] = "Dropping Long Sequences"
+        if handling == "drop":
+            drop_long_kwargs["desc"] = "Dropping Long Sequences"
+        else:
+            drop_long_kwargs["desc"] = "Truncating Long Sequences"

-    dataset = dataset.filter(
-        drop_long,
-        batched=True,
-        **filter_map_kwargs,
-        **drop_long_kwargs,
-    )
-    if prior_len:
-        dropped = prior_len - len(dataset)
-        if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from dataset")
+    if handling == "drop":
+        # Use filter for drop mode
+        dataset = dataset.filter(
+            seq_handler,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
+        if prior_len:
+            dropped = prior_len - len(dataset)
+            if dropped:
+                LOG.warning(f"Dropped {dropped} long samples from dataset")
+    else:
+        # Use map for truncate mode
+        dataset = dataset.map(
+            seq_handler,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
+        LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")

     return dataset
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 9db374409..843a659c9 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -186,6 +186,10 @@ class AxolotlInputConfig(
     unfrozen_parameters: list[str] | None = None

     sequence_len: int = Field(default=512)
+    excess_token_handling: Literal["drop", "truncate"] = Field(
+        default="drop",
+        json_schema_extra={"description": "how to handle tokens exceeding the max sequence length - drop the sample or truncate it"},
+    )
     min_sample_len: int | None = None
     max_prompt_len: int = Field(
         default=512,
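A quick sketch (not part of the patch, using bare pydantic v2 rather than
axolotl's full schema) of what the Literal annotation buys: any value other
than "drop" or "truncate" is rejected at config-validation time, before
training starts:

    from typing import Literal

    from pydantic import BaseModel, Field, ValidationError

    class Demo(BaseModel):
        excess_token_handling: Literal["drop", "truncate"] = Field(default="drop")

    assert Demo().excess_token_handling == "drop"
    assert Demo(excess_token_handling="truncate").excess_token_handling == "truncate"
    try:
        Demo(excess_token_handling="clip")
    except ValidationError:
        print("'clip' is rejected at validation time")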
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 96f54b39d..9e1576663 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -235,6 +235,104 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):

     return results


+def truncate_or_drop_long_seq(
+    sample, sequence_len=2048, min_sequence_len=2, handling="drop"
+):
+    """
+    Either drop or truncate samples whose sequence length is too long
+    (> sequence_len) or too short (< min_sequence_len).
+
+    If handling is "drop":
+    - Samples that are too short or too long will be dropped.
+    If handling is "truncate":
+    - Samples that are too short are returned unchanged, since `map` cannot drop rows.
+    - Samples that are too long are truncated to sequence_len.
+
+    Works for both single examples (input_ids as list[int]) and batches
+    (input_ids as list[list[int]]). Returns a boolean or list of booleans in
+    drop mode, and the (possibly modified) sample in truncate mode.
+    """
+    min_sequence_len = min_sequence_len or 2
+
+    if handling == "drop":
+        return drop_long_seq(sample, sequence_len, min_sequence_len)
+
+    # handling == "truncate" from here on
+    input_ids = sample["input_ids"]
+
+    # Edge case: empty input_ids
+    if not input_ids:
+        return sample
+
+    # Distinguish a single example from a batch by the first element
+    if isinstance(input_ids[0], int):
+        # Single example (input_ids is a list of ints)
+        length = len(input_ids)
+
+        # Samples that are too short are returned unchanged
+        if length < min_sequence_len:
+            return sample
+
+        # Truncate the sample if it is too long
+        if length > sequence_len:
+            sample["input_ids"] = input_ids[:sequence_len]
+
+            # Also truncate attention_mask if present
+            if "attention_mask" in sample:
+                sample["attention_mask"] = sample["attention_mask"][:sequence_len]
+
+            # Also truncate labels if present
+            if "labels" in sample:
+                sample["labels"] = sample["labels"][:sequence_len]
+
+            # Also truncate position_ids if present
+            if "position_ids" in sample:
+                sample["position_ids"] = sample["position_ids"][:sequence_len]
+
+            # Update length if present
+            if "length" in sample:
+                sample["length"] = sequence_len
+
+        return sample
+
+    # Batched (input_ids is a list of lists)
+    for i, seq in enumerate(input_ids):
+        length = len(seq)
+
+        # Skip sequences that are too short
+        if length < min_sequence_len:
+            continue
+
+        # Truncate sequences that are too long
+        if length > sequence_len:
+            input_ids[i] = seq[:sequence_len]
+
+            # Also truncate attention_mask if present
+            if "attention_mask" in sample:
+                sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
+
+            # Also truncate labels if present
+            if "labels" in sample:
+                sample["labels"][i] = sample["labels"][i][:sequence_len]
+
+            # Also truncate position_ids if present
+            if "position_ids" in sample:
+                sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
+
+            # Update length if present
+            if "length" in sample:
+                sample["length"][i] = sequence_len
+
+    return sample
+
+
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
     if drop_attn_mask:
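For illustration, a sketch (not part of the patch) of calling the new helper
on a batched sample, the shape that datasets.map(batched=True) passes in; the
import assumes this patch is applied:

    from axolotl.utils.trainer import truncate_or_drop_long_seq

    batch = {
        "input_ids": [[1] * 10, [2] * 3000],
        "labels": [[1] * 10, [2] * 3000],
    }
    out = truncate_or_drop_long_seq(batch, sequence_len=2048, handling="truncate")
    assert len(out["input_ids"][0]) == 10    # short enough, left untouched
    assert len(out["input_ids"][1]) == 2048  # over-long row truncated in place
    assert len(out["labels"][1]) == 2048     # aligned fields truncated too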
Sequences", - load_from_cache_file=False, - ) + drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len) + + # Use filter for drop mode and map for truncate mode + if handling == "drop": + train_dataset = train_dataset.filter( + drop_long_fn, + desc="Dropping Long Sequences", + load_from_cache_file=False, + ) + else: + truncate_fn = partial(truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling) + train_dataset = train_dataset.map( + truncate_fn, + desc="Truncating Long Sequences", + load_from_cache_file=False, + ) + if not skip_position_ids: train_dataset = train_dataset.map( add_position_ids,