Implement configurable handling of excess tokens in datasets
- Added `excess_token_handling` option to the configuration, allowing users to choose between "drop" and "truncate" for handling tokens exceeding the maximum sequence length. - Introduced `truncate_or_drop_long_seq` function to manage both single and batched samples based on the selected handling method. - Updated relevant dataset processing functions to utilize the new handling option, ensuring backward compatibility with existing "drop" behavior. - Enhanced logging to reflect truncation actions in dataset processing. This change improves flexibility in managing sequence lengths during training and evaluation.
This commit is contained in:
@@ -332,6 +332,8 @@ dataset_shard_idx:
|
|||||||
# The maximum length of an input to train with, this should typically be less than 2048
|
# The maximum length of an input to train with, this should typically be less than 2048
|
||||||
# as most models have a token/context limit of 2048
|
# as most models have a token/context limit of 2048
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
|
# How to handle tokens exceeding max sequence length - "drop" (default, removes sample) or "truncate" (cuts off excess tokens)
|
||||||
|
excess_token_handling: drop
|
||||||
# Pad inputs so each step uses constant sized buffers
|
# Pad inputs so each step uses constant sized buffers
|
||||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
|
|||||||
@@ -259,6 +259,8 @@ def encode_packed_pretraining(
|
|||||||
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
|
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
|
||||||
# workaround by using the position id logic for now in trainer
|
# workaround by using the position id logic for now in trainer
|
||||||
drop_attention_mask=multipack_attn,
|
drop_attention_mask=multipack_attn,
|
||||||
|
# pass through handling mode from config via ds_wrapper function
|
||||||
|
handling=getattr(ds_wrapper, "cfg", {}).get("excess_token_handling", "drop"),
|
||||||
)
|
)
|
||||||
|
|
||||||
sampler = MultipackBatchSampler(
|
sampler = MultipackBatchSampler(
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def drop_long_rl_seq(
|
def drop_long_rl_seq(
|
||||||
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
|
sample, rl, tokenizer, sequence_len, handling="drop" # pylint: disable=invalid-name
|
||||||
):
|
):
|
||||||
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
if rl in ("dpo", "ipo", "orpo", "simpo"):
|
||||||
if not (
|
if not (
|
||||||
@@ -96,9 +96,39 @@ def drop_long_rl_seq(
|
|||||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
||||||
|
|
||||||
return (len_prompt + len_chosen) <= sequence_len and (
|
if handling == "drop":
|
||||||
len_prompt + len_rejected
|
return (len_prompt + len_chosen) <= sequence_len and (
|
||||||
) <= sequence_len
|
len_prompt + len_rejected
|
||||||
|
) <= sequence_len
|
||||||
|
else: # truncate
|
||||||
|
# If both sequences fit, return sample unchanged
|
||||||
|
if (len_prompt + len_chosen) <= sequence_len and (
|
||||||
|
len_prompt + len_rejected
|
||||||
|
) <= sequence_len:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# For truncation, we need to truncate the chosen and rejected responses
|
||||||
|
# to fit within sequence_len, but preserve the prompt
|
||||||
|
|
||||||
|
# Calculate maximum response length that can fit with the prompt
|
||||||
|
max_response_len = sequence_len - len_prompt
|
||||||
|
|
||||||
|
if max_response_len <= 0:
|
||||||
|
# Prompt is already too long, we can't truncate effectively
|
||||||
|
return False if handling == "drop" else sample
|
||||||
|
|
||||||
|
# Truncate the chosen and rejected responses if needed
|
||||||
|
if len_chosen > max_response_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
chosen_tokens = tokenizer(chosen, add_special_tokens=False)["input_ids"][:max_response_len]
|
||||||
|
sample["chosen"] = tokenizer.decode(chosen_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
if len_rejected > max_response_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
rejected_tokens = tokenizer(rejected, add_special_tokens=False)["input_ids"][:max_response_len]
|
||||||
|
sample["rejected"] = tokenizer.decode(rejected_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
if rl == "kto":
|
if rl == "kto":
|
||||||
if not (sample.get("prompt") and sample.get("completion")):
|
if not (sample.get("prompt") and sample.get("completion")):
|
||||||
@@ -112,10 +142,30 @@ def drop_long_rl_seq(
|
|||||||
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
tokenizer(completion, add_special_tokens=False)["input_ids"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return (len_prompt + len_completion) <= sequence_len
|
if handling == "drop":
|
||||||
|
return (len_prompt + len_completion) <= sequence_len
|
||||||
|
else: # truncate
|
||||||
|
# If sequence fits, return sample unchanged
|
||||||
|
if (len_prompt + len_completion) <= sequence_len:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# Calculate maximum completion length that can fit with the prompt
|
||||||
|
max_completion_len = sequence_len - len_prompt
|
||||||
|
|
||||||
|
if max_completion_len <= 0:
|
||||||
|
# Prompt is already too long, we can't truncate effectively
|
||||||
|
return False if handling == "drop" else sample
|
||||||
|
|
||||||
|
# Truncate the completion if needed
|
||||||
|
if len_completion > max_completion_len:
|
||||||
|
# Tokenize, truncate, and decode
|
||||||
|
completion_tokens = tokenizer(completion, add_special_tokens=False)["input_ids"][:max_completion_len]
|
||||||
|
sample["completion"] = tokenizer.decode(completion_tokens, skip_special_tokens=True)
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
if rl == "grpo":
|
if rl == "grpo":
|
||||||
return True
|
return True if handling == "drop" else sample
|
||||||
|
|
||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
@@ -169,20 +219,33 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
rl=_cfg.rl,
|
rl=_cfg.rl,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
sequence_len=cfg.sequence_len,
|
sequence_len=cfg.sequence_len,
|
||||||
|
handling=cfg.get("excess_token_handling", "drop"),
|
||||||
)
|
)
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
prior_len = len(split_datasets[i])
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
|
||||||
drop_long,
|
# Use filter for drop mode and map for truncate mode
|
||||||
num_proc=cfg.dataset_processes,
|
handling = cfg.get("excess_token_handling", "drop")
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
if handling == "drop":
|
||||||
desc="Dropping Long Sequences",
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
)
|
drop_long,
|
||||||
dropped = prior_len - len(split_datasets[i])
|
num_proc=cfg.dataset_processes,
|
||||||
if dropped:
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
LOG.warning(
|
desc="Dropping Long Sequences",
|
||||||
f"Dropped {dropped} long samples from dataset index {i}"
|
|
||||||
)
|
)
|
||||||
|
dropped = prior_len - len(split_datasets[i])
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} long samples from dataset index {i}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
split_datasets[i] = split_datasets[i].map(
|
||||||
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Truncating Long Sequences",
|
||||||
|
)
|
||||||
|
LOG.info(f"Truncated long samples in dataset index {i} to {cfg.sequence_len} tokens")
|
||||||
|
|
||||||
combined_datasets = concatenate_datasets(split_datasets)
|
combined_datasets = concatenate_datasets(split_datasets)
|
||||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from datasets import Dataset, IterableDataset
|
|||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.samplers.utils import get_dataset_lengths
|
from axolotl.utils.samplers.utils import get_dataset_lengths
|
||||||
from axolotl.utils.trainer import drop_long_seq
|
from axolotl.utils.trainer import drop_long_seq, truncate_or_drop_long_seq
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -165,11 +165,24 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
)
|
)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
drop_long = functools.partial(
|
# Get the handling method from config, default to "drop" for backward compatibility
|
||||||
drop_long_seq,
|
handling = cfg.get("excess_token_handling", "drop")
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
min_sequence_len=cfg.min_sample_len,
|
if handling == "drop":
|
||||||
)
|
# Use the existing drop_long_seq function for backward compatibility
|
||||||
|
seq_handler = functools.partial(
|
||||||
|
drop_long_seq,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
min_sequence_len=cfg.min_sample_len,
|
||||||
|
)
|
||||||
|
else: # handling == "truncate"
|
||||||
|
# Use the new function with truncate mode
|
||||||
|
seq_handler = functools.partial(
|
||||||
|
truncate_or_drop_long_seq,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
min_sequence_len=cfg.min_sample_len,
|
||||||
|
handling=handling,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
ds_lengths = get_dataset_lengths(dataset, from_arrow=True)
|
||||||
@@ -193,17 +206,31 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
|
|||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
if handling == "drop":
|
||||||
|
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||||
|
else:
|
||||||
|
drop_long_kwargs["desc"] = "Truncating Long Sequences"
|
||||||
|
|
||||||
dataset = dataset.filter(
|
if handling == "drop":
|
||||||
drop_long,
|
# Use filter for drop mode
|
||||||
batched=True,
|
dataset = dataset.filter(
|
||||||
**filter_map_kwargs,
|
seq_handler,
|
||||||
**drop_long_kwargs,
|
batched=True,
|
||||||
)
|
**filter_map_kwargs,
|
||||||
if prior_len:
|
**drop_long_kwargs,
|
||||||
dropped = prior_len - len(dataset)
|
)
|
||||||
if dropped:
|
if prior_len:
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
dropped = prior_len - len(dataset)
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(f"Dropped {dropped} long samples from dataset")
|
||||||
|
else:
|
||||||
|
# Use map for truncate mode
|
||||||
|
dataset = dataset.map(
|
||||||
|
seq_handler,
|
||||||
|
batched=True,
|
||||||
|
**filter_map_kwargs,
|
||||||
|
**drop_long_kwargs,
|
||||||
|
)
|
||||||
|
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|||||||
@@ -186,6 +186,10 @@ class AxolotlInputConfig(
|
|||||||
unfrozen_parameters: list[str] | None = None
|
unfrozen_parameters: list[str] | None = None
|
||||||
|
|
||||||
sequence_len: int = Field(default=512)
|
sequence_len: int = Field(default=512)
|
||||||
|
excess_token_handling: Literal["drop", "truncate"] = Field(
|
||||||
|
default="drop",
|
||||||
|
json_schema_extra={"description": "how to handle tokens exceeding max sequence length - drop the sample or truncate"},
|
||||||
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int = Field(
|
||||||
default=512,
|
default=512,
|
||||||
|
|||||||
@@ -235,6 +235,104 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_or_drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, handling="drop"):
|
||||||
|
"""
|
||||||
|
Either drop or truncate samples whose sequence length is either too long (> sequence_len)
|
||||||
|
or too short (< min_sequence_len).
|
||||||
|
|
||||||
|
If handling is "drop":
|
||||||
|
- Samples that are too short or too long will be dropped
|
||||||
|
If handling is "truncate":
|
||||||
|
- Samples that are too short will still be dropped
|
||||||
|
- Samples that are too long will be truncated to sequence_len
|
||||||
|
|
||||||
|
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||||
|
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
|
||||||
|
"""
|
||||||
|
min_sequence_len = min_sequence_len or 2
|
||||||
|
|
||||||
|
if handling == "drop":
|
||||||
|
return drop_long_seq(sample, sequence_len, min_sequence_len)
|
||||||
|
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
|
||||||
|
# Edge case: if input_ids is empty
|
||||||
|
if not input_ids:
|
||||||
|
return False if handling == "drop" else sample
|
||||||
|
|
||||||
|
# Check if single example or batched by looking at the first element
|
||||||
|
if isinstance(input_ids[0], int):
|
||||||
|
# Single example (input_ids is a list of int)
|
||||||
|
length = len(input_ids)
|
||||||
|
|
||||||
|
# Handle samples that are too short - always drop them
|
||||||
|
if length < min_sequence_len:
|
||||||
|
return False if handling == "drop" else sample
|
||||||
|
|
||||||
|
# If truncation is enabled and the sample is too long, truncate it
|
||||||
|
if length > sequence_len and handling == "truncate":
|
||||||
|
sample["input_ids"] = input_ids[:sequence_len]
|
||||||
|
|
||||||
|
# Also truncate attention_mask if present
|
||||||
|
if "attention_mask" in sample:
|
||||||
|
sample["attention_mask"] = sample["attention_mask"][:sequence_len]
|
||||||
|
|
||||||
|
# Also truncate labels if present
|
||||||
|
if "labels" in sample:
|
||||||
|
sample["labels"] = sample["labels"][:sequence_len]
|
||||||
|
|
||||||
|
# Also truncate position_ids if present
|
||||||
|
if "position_ids" in sample:
|
||||||
|
sample["position_ids"] = sample["position_ids"][:sequence_len]
|
||||||
|
|
||||||
|
# Update length if present
|
||||||
|
if "length" in sample:
|
||||||
|
sample["length"] = sequence_len
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# For drop mode or if the sample doesn't exceed max length
|
||||||
|
return min_sequence_len <= length <= sequence_len if handling == "drop" else sample
|
||||||
|
|
||||||
|
# Batched (input_ids is a list of lists)
|
||||||
|
if handling == "drop":
|
||||||
|
results = []
|
||||||
|
for seq in input_ids:
|
||||||
|
length = len(seq)
|
||||||
|
results.append(min_sequence_len <= length <= sequence_len)
|
||||||
|
return results
|
||||||
|
else: # truncate
|
||||||
|
# Check each sequence in the batch
|
||||||
|
for i, seq in enumerate(input_ids):
|
||||||
|
length = len(seq)
|
||||||
|
|
||||||
|
# Skip sequences that are too short
|
||||||
|
if length < min_sequence_len:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Truncate sequences that are too long
|
||||||
|
if length > sequence_len:
|
||||||
|
input_ids[i] = seq[:sequence_len]
|
||||||
|
|
||||||
|
# Also truncate attention_mask if present
|
||||||
|
if "attention_mask" in sample:
|
||||||
|
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
|
||||||
|
|
||||||
|
# Also truncate labels if present
|
||||||
|
if "labels" in sample:
|
||||||
|
sample["labels"][i] = sample["labels"][i][:sequence_len]
|
||||||
|
|
||||||
|
# Also truncate position_ids if present
|
||||||
|
if "position_ids" in sample:
|
||||||
|
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
|
||||||
|
|
||||||
|
# Update length if present
|
||||||
|
if "length" in sample:
|
||||||
|
sample["length"][i] = sequence_len
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
|
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
|
||||||
if drop_attn_mask:
|
if drop_attn_mask:
|
||||||
@@ -370,15 +468,25 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
|
|
||||||
|
|
||||||
def process_pretraining_datasets_for_packing(
|
def process_pretraining_datasets_for_packing(
|
||||||
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
|
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False, handling="drop"
|
||||||
):
|
):
|
||||||
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
|
drop_long_fn = partial(drop_long_seq, sequence_len=sequence_len)
|
||||||
|
|
||||||
|
# Use filter for drop mode and map for truncate mode
|
||||||
|
if handling == "drop":
|
||||||
|
train_dataset = train_dataset.filter(
|
||||||
|
drop_long_fn,
|
||||||
|
desc="Dropping Long Sequences",
|
||||||
|
load_from_cache_file=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
truncate_fn = partial(truncate_or_drop_long_seq, sequence_len=sequence_len, handling=handling)
|
||||||
|
train_dataset = train_dataset.map(
|
||||||
|
truncate_fn,
|
||||||
|
desc="Truncating Long Sequences",
|
||||||
|
load_from_cache_file=False,
|
||||||
|
)
|
||||||
|
|
||||||
train_dataset = train_dataset.filter(
|
|
||||||
drop_long,
|
|
||||||
desc="Dropping Long Sequences",
|
|
||||||
load_from_cache_file=False,
|
|
||||||
)
|
|
||||||
if not skip_position_ids:
|
if not skip_position_ids:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user