Compare commits

...

13 Commits

Author SHA1 Message Date
coderabbitai[bot]
e23a5c9fda 📝 Add docstrings to 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length
Docstrings generation was requested by @mhenrichsen.

* https://github.com/axolotl-ai-cloud/axolotl/pull/2662#issuecomment-2883401776

The following files were modified:

* `src/axolotl/utils/data/pretraining.py`
* `src/axolotl/utils/data/rl.py`
* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
* `tests/test_data.py`
* `tests/test_trainer_utils.py`
2025-05-15 11:02:45 +00:00
mhenrhcsen
5d7a61576d Refactor sequence length overflow handling in pretraining module
- Introduced DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING constant in utils.py.
- Updated encode_packed_pretraining function to use this constant instead of a hardcoded value.
2025-05-15 12:55:09 +02:00
mhenrhcsen
5ecf22b54e Merge branch 'main' of github.com:axolotl-ai-cloud/axolotl into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length 2025-05-14 13:36:43 +02:00
mhenrhcsen
9c5b8da22f fix merge conflicts 2025-05-14 13:33:42 +02:00
mhenrhcsen
fea6649518 increased test coverage 2025-05-13 08:58:34 +02:00
mhenrhcsen
124ad2b968 lint 2025-05-13 08:35:16 +02:00
mhenrhcsen
767c2340f1 docstring for tests 2025-05-12 22:57:43 +02:00
mhenrhcsen
f6623c34cc Linting fix 2025-05-12 22:53:30 +02:00
mhenrhcsen
5dd8f0b2b8 Fixes comments from winglian 2025-05-12 22:43:15 +02:00
mhenrhcsen
be3c6bbd85 fix linting issues 2025-05-12 14:46:57 +02:00
mhenrhcsen
f07db4f853 Refactor truncation logic in drop_long_rl_seq function
- Simplified the truncation process for chosen and rejected responses to ensure they fit within the specified sequence length while preserving the prompt.
- Improved readability by restructuring the code and removing redundant checks.
- Ensured that the function returns the sample correctly after processing, maintaining compatibility with existing handling options.
2025-05-12 14:40:10 +02:00
mhenrhcsen
17a5838d38 lint 2025-05-12 14:36:43 +02:00
mhenrhcsen
9f68918f13 Implement configurable handling of excess tokens in datasets
- Added `excess_token_handling` option to the configuration, allowing users to choose between "drop" and "truncate" for handling tokens exceeding the maximum sequence length.
- Introduced `truncate_or_drop_long_seq` function to manage both single and batched samples based on the selected handling method.
- Updated relevant dataset processing functions to utilize the new handling option, ensuring backward compatibility with existing "drop" behavior.
- Enhanced logging to reflect truncation actions in dataset processing.

This change improves flexibility in managing sequence lengths during training and evaluation.
2025-05-12 14:08:43 +02:00
8 changed files with 891 additions and 44 deletions

View File

@@ -332,6 +332,8 @@ dataset_shard_idx:
# The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048
sequence_len: 2048
# How to handle sequences that overflow the sequence_len: 'drop' (default, removes sample) or 'truncate' (cuts off excess tokens).
sequence_len_overflow_handling: drop
# Pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len:

View File

@@ -11,6 +11,7 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
@@ -250,6 +251,22 @@ def encode_packed_pretraining(
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
"""
Encodes and packs input examples into fixed-length batches for pretraining with optional multipack attention.
Wraps and processes input examples into a dataset, applies sequence packing with configurable overflow handling, and batches the data using a multipack sampler. Each batch is collated and features are aggregated into lists keyed by feature name.
Args:
collate_fn: Function to collate individual feature dictionaries into batch tensors.
ds_wrapper: Callable that wraps a Hugging Face Dataset for further processing.
examples: Dictionary of input examples to encode and pack.
max_seq_length: Maximum sequence length for each packed sequence.
batch_size: Number of sequences to pack per batch.
multipack_attn: If True, enables multipack attention and drops attention masks.
Returns:
Dictionary where each key is a feature name and each value is a list of packed feature tensors.
"""
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(
@@ -259,6 +276,10 @@ def encode_packed_pretraining(
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
# pass through handling mode from config via ds_wrapper function
handling=getattr(ds_wrapper, "cfg", {}).get(
"sequence_len_overflow_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
),
)
sampler = MultipackBatchSampler(

View File

@@ -79,8 +79,33 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
sample,
rl,
tokenizer,
sequence_len,
handling="drop", # Use the default handling mode
):
"""
Handles samples exceeding a maximum sequence length for various RL dataset types by either truncating or dropping them.
Depending on the RL type and the `handling` mode, this function either truncates response fields to fit within the specified sequence length or determines whether the sample should be dropped. For DPO, IPO, ORPO, and SIMPO types, both "chosen" and "rejected" responses are considered; for KTO, the "completion" is considered. For GRPO, samples are always retained. If truncation is not possible (e.g., the prompt alone exceeds the limit), the sample is returned unchanged for mapping, or dropped during filtering.
Args:
sample: A dictionary representing a single dataset sample.
rl: The RLType indicating the dataset type.
tokenizer: The tokenizer used to compute token lengths and perform truncation.
sequence_len: The maximum allowed sequence length.
handling: Specifies how to handle overlong sequences ("drop" or "truncate").
Returns:
For "truncate": The modified sample with responses truncated as needed, or the original sample if truncation is not possible.
For "drop": True if the sample fits within the sequence length, otherwise False.
Raises:
ValueError: If required keys are missing for the specified RL type, or if the RL type is unknown.
"""
result = None
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
@@ -97,11 +122,65 @@ def drop_long_rl_seq(
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
# Truncate first, then drop if still invalid (although truncate should handle it)
if handling == "truncate":
# If both sequences fit, return sample unchanged
if (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len:
result = sample
else:
# Calculate maximum response length that can fit with the prompt
max_response_len = sequence_len - len_prompt
if rl is RLType.KTO:
if max_response_len <= 0:
# Prompt is already too long, behavior depends on handling
# If truncate is chosen, we technically can't truncate, but drop seems harsh.
# Returning the sample might be unexpected. Let's stick to the filter logic
# which would drop this in the `filter` step later if needed.
# For now, return sample to map, or False to filter.
# Let's simplify: truncate *should* result in a valid sample if possible.
# If prompt >= seq_len, truncate won't work. Filter will catch this later.
# So, if max_response_len <= 0, we pass it through for map, drop for filter.
# However, the filter/map logic is applied *after* this function.
# This function needs to return the *modified* sample for map, or bool for filter.
# Re-think: If handling==truncate, return the modified sample if possible.
# If prompt >= seq_len, modification is impossible. What should map return?
# Maybe return the original sample? But map expects *modified* sample.
# Let's stick to the original logic: if prompt is too long, return False for filter
# and original sample for map.
result = (
sample # For map, let downstream handle it if still invalid?
)
# Or maybe return None/empty dict? Let's return sample for now.
# If handling was drop, filter would remove this.
else:
# Truncate the chosen and rejected responses if needed
if len_chosen > max_response_len:
chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
"input_ids"
][:max_response_len]
sample["chosen"] = tokenizer.decode(
chosen_tokens, skip_special_tokens=True
)
if len_rejected > max_response_len:
rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
"input_ids"
][:max_response_len]
sample["rejected"] = tokenizer.decode(
rejected_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
elif rl == RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -113,15 +192,54 @@ def drop_long_rl_seq(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
# Truncate first
if handling == "truncate":
# If sequence fits, return sample unchanged
if (len_prompt + len_completion) <= sequence_len:
result = sample
else:
# Calculate maximum completion length
max_completion_len = sequence_len - len_prompt
if rl is RLType.GRPO:
return True
if max_completion_len <= 0:
# Prompt too long, return sample for map
result = sample
else:
# Truncate the completion if needed
if len_completion > max_completion_len:
completion_tokens = tokenizer(
completion, add_special_tokens=False
)["input_ids"][:max_completion_len]
sample["completion"] = tokenizer.decode(
completion_tokens, skip_special_tokens=True
)
result = sample
else: # handling == "drop"
result = (len_prompt + len_completion) <= sequence_len
raise ValueError("Unknown RL type")
elif rl == RLType.GRPO:
# GRPO doesn't involve sequence length checks in the same way?
# The original code returned True for drop. What should it return for truncate?
# Let's assume for now it always passes.
result = sample if handling == "truncate" else True
else:
raise ValueError("Unknown RL type")
return result
def load_prepare_preference_datasets(cfg):
"""
Loads, preprocesses, and prepares preference datasets for RL training and evaluation.
This function orchestrates the loading, transformation, sequence length handling, optional deduplication, and caching of datasets for Direct Preference Optimization (DPO) and related RL types. It supports configurable handling of overlong sequences (dropping or truncating), applies dataset-specific transformations, and manages train/validation/test splits as needed.
Args:
cfg: Configuration object specifying dataset sources, RL type, tokenizer, sequence length, and processing options.
Returns:
A tuple containing the prepared training and evaluation datasets.
"""
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token
@@ -165,25 +283,43 @@ def load_prepare_preference_datasets(cfg):
split_datasets[i] = data_set
if not cfg.skip_prepare_dataset:
# Determine handling mode
handling = cfg.get("sequence_len_overflow_handling", "drop")
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
handling=handling, # Pass the handling mode
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
split_datasets[i] = split_datasets[i].map(
drop_long, # Function now returns modified sample or original
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# Note: Length might not change if truncation always occurs
LOG.info(
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
)
else: # handling == "drop"
split_datasets[i] = split_datasets[i].filter(
drop_long, # Function now returns boolean
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)

View File

@@ -13,10 +13,12 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
from axolotl.utils.trainer import truncate_or_drop_long_seq
LOG = logging.getLogger(__name__)
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
class RetryStrategy(Enum):
"""
@@ -159,16 +161,33 @@ def deduplicate_and_log_datasets(
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
"""
Processes a dataset to handle sequences exceeding a configured maximum length by either truncating or dropping them.
If the dataset lacks an "input_ids" column, the function returns the dataset unchanged. The handling mode is determined by the configuration parameter "sequence_len_overflow_handling", defaulting to "drop". In "truncate" mode, sequences longer than the maximum length are truncated; in "drop" mode, such sequences are removed from the dataset. The function logs information about sequence lengths and the number of samples affected when applicable.
Args:
dataset: The Huggingface Dataset to process.
cfg: Configuration object specifying sequence length parameters and handling mode.
Returns:
The processed dataset with long sequences either truncated or dropped according to the configuration.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset
drop_long = functools.partial(
drop_long_seq,
# Get the handling method from config, default to "drop" for backward compatibility
handling = cfg.get("sequence_len_overflow_handling", "drop")
# Use the new function with the specified handling mode
seq_handler = functools.partial(
truncate_or_drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
handling=handling,
)
try:
@@ -193,17 +212,31 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
if handling == "truncate":
drop_long_kwargs["desc"] = "Truncating Long Sequences"
else: # handling == "drop"
drop_long_kwargs["desc"] = "Dropping Long Sequences"
dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
if handling == "truncate":
# Use map for truncate mode
dataset = dataset.map(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
LOG.info(f"Truncated long samples in dataset to {cfg.sequence_len} tokens")
else: # handling == "drop"
# Use filter for drop mode
dataset = dataset.filter(
seq_handler,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset

View File

@@ -186,6 +186,12 @@ class AxolotlInputConfig(
unfrozen_parameters: list[str] | None = None
sequence_len: int = Field(default=512)
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
default="drop",
json_schema_extra={
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
},
)
min_sample_len: int | None = None
max_prompt_len: int = Field(
default=512,

View File

@@ -207,10 +207,18 @@ def add_length(sample):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
Determines whether samples should be kept based on sequence length constraints.
For a single example or a batch, returns True (or a list of booleans) if each sequence's length is within the specified range; otherwise, returns False (or a list with False for out-of-range sequences).
Args:
sample: A dictionary containing "input_ids" as a list of ints or a list of lists of ints.
sequence_len: Maximum allowed sequence length (inclusive).
min_sequence_len: Minimum allowed sequence length (inclusive).
Returns:
True if the single example is within the length range, False otherwise.
For batched input, returns a list of booleans indicating which sequences are within the range.
"""
min_sequence_len = min_sequence_len or 2
@@ -235,7 +243,121 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return results
def truncate_or_drop_long_seq(
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
):
"""
Drops or truncates samples based on sequence length constraints.
If handling is "drop", returns a boolean or list of booleans indicating whether each sample's sequence length is within the specified range. If handling is "truncate", returns the sample with sequences longer than sequence_len truncated and sequences shorter than min_sequence_len omitted. Supports both single-example and batched inputs.
Args:
sample: A dictionary containing at least an "input_ids" field, representing either a single sequence or a batch of sequences.
sequence_len: Maximum allowed sequence length.
min_sequence_len: Minimum allowed sequence length.
handling: "drop" to filter out samples outside the range, "truncate" to truncate long sequences.
Returns:
In "drop" mode, a boolean or list of booleans indicating which samples to keep. In "truncate" mode, the modified sample with sequences truncated as needed.
"""
min_sequence_len = min_sequence_len or 2
result = None
if handling == "drop":
return drop_long_seq(sample, sequence_len, min_sequence_len)
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
if not input_ids:
result = False if handling == "drop" else sample
# Single example (input_ids is a list of int)
elif isinstance(input_ids[0], int):
length = len(input_ids)
# Handle samples that are too short - always drop them
if length < min_sequence_len:
result = False if handling == "drop" else sample
# If truncation is enabled and the sample is too long, truncate it
elif length > sequence_len and handling == "truncate":
sample["input_ids"] = input_ids[:sequence_len]
# Also truncate attention_mask if present
if "attention_mask" in sample:
sample["attention_mask"] = sample["attention_mask"][:sequence_len]
# Also truncate labels if present
if "labels" in sample:
sample["labels"] = sample["labels"][:sequence_len]
# Also truncate position_ids if present
if "position_ids" in sample:
sample["position_ids"] = sample["position_ids"][:sequence_len]
# Update length if present
if "length" in sample:
sample["length"] = sequence_len
result = sample
# For drop mode or if the sample doesn't exceed max length
else:
result = (
min_sequence_len <= length <= sequence_len
if handling == "drop"
else sample
)
# Batched (input_ids is a list of lists)
else:
if handling == "drop":
results = []
for seq in input_ids:
length = len(seq)
results.append(min_sequence_len <= length <= sequence_len)
result = results
else: # truncate
# Check each sequence in the batch
for i, seq in enumerate(input_ids):
length = len(seq)
# Skip sequences that are too short
if length < min_sequence_len:
continue
# Truncate sequences that are too long
if length > sequence_len:
input_ids[i] = seq[:sequence_len]
# Also truncate attention_mask if present
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][
:sequence_len
]
# Also truncate labels if present
if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len]
# Also truncate position_ids if present
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][
:sequence_len
]
# Update length if present
if "length" in sample:
sample["length"][i] = sequence_len
result = sample
return result
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
"""
Prepares training and evaluation datasets for sample packing and model-specific requirements.
Removes unnecessary columns based on model type, filters out samples with no trainable tokens, and optionally adds length or position ID columns for sample packing or PoSE techniques. Returns the processed training and evaluation datasets.
"""
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
LOG.info("dropping attention_mask column")
@@ -370,15 +492,48 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
train_dataset,
sequence_len,
skip_position_ids=True,
drop_attention_mask=False,
handling="drop",
):
drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter(
drop_long,
desc="Dropping Long Sequences",
load_from_cache_file=False,
# Define the function to use for handling sequences based on the mode
"""
Processes a pretraining dataset by truncating or dropping sequences based on length.
Depending on the handling mode, sequences longer than `sequence_len` are either truncated or dropped, and sequences shorter than `min_sequence_len` are dropped. Optionally adds position IDs and removes the attention mask column.
Args:
train_dataset: The dataset to process.
sequence_len: Maximum allowed sequence length.
skip_position_ids: If False, adds position IDs to each sample.
drop_attention_mask: If True, removes the attention mask column.
handling: "drop" to remove long sequences, "truncate" to truncate them.
Returns:
The processed dataset with sequences handled according to the specified mode.
"""
seq_handler_fn = partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,
handling=handling, # Pass handling mode
)
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
train_dataset = train_dataset.map(
seq_handler_fn,
desc="Truncating Long Sequences",
load_from_cache_file=False,
)
else: # handling == "drop"
train_dataset = train_dataset.filter(
seq_handler_fn, # Use the same function, it returns boolean for drop mode
desc="Dropping Long Sequences",
load_from_cache_file=False,
)
if not skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,

View File

@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
"""
import unittest
from unittest.mock import MagicMock
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
from axolotl.utils.data.rl import drop_long_rl_seq
from tests.hf_offline_utils import enable_hf_offline
@@ -58,11 +60,328 @@ class TestEncodePretraining(unittest.TestCase):
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
def test_md5(self):
"""
Tests that the md5 function returns the correct hash for a given string and encoding.
"""
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
self.assertEqual(
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)
class TestDropLongRLSeq(unittest.TestCase):
"""
Tests for the drop_long_rl_seq function.
"""
def setUp(self):
# Mock tokenizer that returns length based on input string length
"""
Sets up a mock tokenizer and sequence length for RL sequence length tests.
The mock tokenizer simulates tokenization by returning input IDs equal to the input string's length and decodes tokens as repeated "x" characters. The sequence length limit is set to 20.
"""
self.tokenizer = MagicMock()
def side_effect_func(
text, add_special_tokens=False
): # pylint: disable=unused-argument
"""
Simulates tokenization by returning input IDs as a sequence of integers equal to the input text length.
Args:
text: The input string to tokenize.
add_special_tokens: Ignored parameter included for interface compatibility.
Returns:
A dictionary with 'input_ids' as a list of integers from 0 to len(text) - 1.
"""
return {"input_ids": list(range(len(text)))}
self.tokenizer.side_effect = side_effect_func
self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
["x"] * len(tokens)
) # pylint: disable=unused-argument
self.sequence_len = 20
def test_dpo_drop_mode_valid(self):
"""
Tests that drop_long_rl_seq returns True in drop mode for a DPO sample within the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_dpo_drop_mode_invalid_chosen(self):
"""
Tests that in DPO drop mode, a sample is rejected when the prompt and chosen lengths exceed the sequence limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 16,
"rejected": "r" * 6,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_drop_mode_invalid_rejected(self):
"""
Tests that in DPO drop mode, a sample is rejected when the prompt plus rejected response exceeds the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 16,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_truncate_mode_no_truncation_needed(self):
"""
Verifies that in DPO truncate mode, samples within the sequence length limit are returned unchanged.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(
result, original_sample
) # Should return the original sample unchanged
def test_dpo_truncate_mode_prompt_too_long(self):
"""
Tests that in DPO truncate mode, if the prompt exceeds the sequence length limit,
the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
# Even though truncation isn't possible, the function should return the original sample
# for the map operation, assuming downstream filtering will catch it.
self.assertEqual(result, original_sample)
def test_dpo_truncate_mode_chosen_truncated(self):
"""
Tests that in DPO truncate mode, only the 'chosen' field is truncated when it exceeds the allowed sequence length, while 'prompt' and 'rejected' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 18,
"rejected": "r" * 10,
} # 5+18=23 > 20, 5+10=15 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["chosen"], "x" * max_resp_len
) # Check decoded truncated value
self.assertEqual(len(result["rejected"]), 10) # Unchanged
def test_dpo_truncate_mode_rejected_truncated(self):
"""
Tests that in DPO truncate mode, only the 'rejected' field is truncated when it exceeds the sequence length limit, while 'prompt' and 'chosen' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 10,
"rejected": "r" * 18,
} # 5+10=15 <= 20, 5+18=23 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), 10) # Unchanged
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["rejected"], "x" * max_resp_len
) # Check decoded truncated value
def test_dpo_truncate_mode_both_truncated(self):
"""
Tests that in DPO truncate mode, both 'chosen' and 'rejected' fields are truncated when their combined lengths with the prompt exceed the sequence limit.
Verifies that both fields are truncated to fit within the allowed response length and replaced with decoded placeholder content.
"""
prompt_len = 8
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 15,
"rejected": "r" * 14,
} # 8+15=23 > 20, 8+14=22 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 12
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 12
self.assertEqual(result["rejected"], "x" * max_resp_len)
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
"""
Tests DPO truncate mode where only the overlong response is truncated.
Verifies that when the prompt plus one response exceeds the sequence length, only the response exceeding the maximum allowed length is truncated, while the other remains unchanged.
"""
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
# but the initial check failed because e.g. prompt + chosen > sequence_len
# The current logic *will* truncate if len(chosen) > max_resp_len.
# Let's test a case where one is slightly too long causing the initial fail,
# but the other fits *within* the max_response_len, so only one gets truncated.
prompt_len = 10
max_resp_len = self.sequence_len - prompt_len # 10
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 11,
"rejected": "r" * 9,
} # 10+11=21 > 20, 10+9=19 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 10
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), 9) # Unchanged, as 9 <= 10
# Add similar tests for KTO if needed, checking prompt + completion length
def test_kto_drop_mode_valid(self):
"""
Tests that drop_long_rl_seq returns True for a KTO sample within the sequence length limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_kto_drop_mode_invalid(self):
"""
Tests that drop_long_rl_seq returns False when a KTO sample exceeds the sequence length limit in drop mode.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_kto_truncate_mode_no_truncation_needed(self):
"""
Tests that KTO truncate mode returns the original sample unchanged when the combined prompt and completion length does not exceed the sequence limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample)
def test_kto_truncate_mode_prompt_too_long(self):
"""
Tests that in KTO truncate mode, if the prompt exceeds the sequence length limit, the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "completion": "c" * 7}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample) # Returns original sample
def test_kto_truncate_mode_completion_truncated(self):
"""
Tests that in KTO truncate mode, the completion is truncated when the combined prompt and completion exceed the sequence length limit.
Verifies that the prompt remains unchanged and the completion is truncated to fit within the allowed length, with the truncated completion replaced by decoded "x" characters.
"""
prompt_len = 8
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["completion"]), max_comp_len) # Truncated to 12
self.assertEqual(result["completion"], "x" * max_comp_len)
def test_missing_keys_dpo(self):
"""
Tests that a ValueError is raised when required keys are missing for DPO samples.
Verifies that the function raises an error if the sample does not contain 'chosen' and 'rejected' keys.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt, chosen and rejected keys are required"
):
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
def test_missing_keys_kto(self):
"""
Tests that a ValueError is raised when required keys are missing for RL type "kto".
Verifies that calling drop_long_rl_seq with a sample missing the "completion" key raises
a ValueError with the expected error message.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt and completion keys are required"
):
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
def test_unknown_rl_type(self):
"""
Tests that a ValueError is raised when an unknown RL type is provided to drop_long_rl_seq.
"""
sample = {}
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
# GRPO test - current implementation always passes
def test_grpo_drop(self):
"""
Tests that drop_long_rl_seq in GRPO drop mode always returns True, regardless of input.
"""
sample = {}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_grpo_truncate(self):
"""
Tests that in truncate mode for RL type "grpo", the original sample is returned unchanged.
"""
sample = {"a": 1}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, sample)
if __name__ == "__main__":
unittest.main()

175
tests/test_trainer_utils.py Normal file
View File

@@ -0,0 +1,175 @@
"""Module containing tests for trainer utility functions."""
import unittest
from functools import partial
from axolotl.utils.trainer import truncate_or_drop_long_seq
# Test cases for truncate_or_drop_long_seq
class TestTruncateOrDropLongSeq(unittest.TestCase):
"""
Test suite for truncate_or_drop_long_seq function.
"""
def setUp(self):
# Example sequence length settings
"""
Sets up default sequence length parameters for the test cases.
"""
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""
Verifies that 'drop' mode correctly filters single sequence examples based on length.
Tests that sequences shorter than the minimum, longer than the maximum, or empty are dropped,
while sequences within the valid length range are kept.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
# Too short
sample_short = {"input_ids": [1, 2]}
self.assertFalse(handler(sample_short))
# Too long
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
self.assertFalse(handler(sample_long))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
self.assertTrue(handler(sample_ok))
# Empty
sample_empty = {"input_ids": []}
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""
Tests that 'truncate_or_drop_long_seq' correctly truncates or preserves single examples in "truncate" mode.
Verifies that sequences longer than the maximum length are truncated, while sequences that are too short, empty, or within the valid range remain unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
# Too short (should still be dropped implicitly by filter/map logic upstream,
# but the function itself might return the sample or False based on impl.)
# Current impl returns the original sample for map if too short, assuming upstream filters.
# Let's refine this test - the function *itself* returns the sample if too short when truncating.
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
result_short = handler(sample_short)
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
# Too long
original_long = list(range(self.sequence_len + 5))
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
result_long = handler(sample_long)
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
self.assertEqual(len(result_long["labels"]), self.sequence_len)
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
# Just right
sample_ok = {
"input_ids": list(range(self.min_sequence_len)),
"labels": list(range(self.min_sequence_len)),
}
result_ok = handler(sample_ok)
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
self.assertEqual(result_ok, sample_ok) # Should be unchanged
# Empty
sample_empty = {"input_ids": [], "labels": []}
result_empty = handler(sample_empty)
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""
Tests that the "drop" handling mode correctly filters batched input sequences based on length constraints.
Verifies that sequences shorter than the minimum length, longer than the maximum length, or empty are dropped (returns False), while sequences within the valid range are kept (returns True).
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 1)), # Too long
list(range(self.sequence_len)), # OK (len = 10)
list(range(self.min_sequence_len)), # OK (len = 3)
[], # Empty
]
}
expected = [False, False, True, True, False]
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""
Tests that batched examples are correctly truncated in "truncate" mode.
Verifies that sequences in both "input_ids" and "labels" longer than the maximum
allowed length are truncated, while sequences that are too short or empty remain
unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 5)), # Too long
list(range(self.sequence_len)), # OK
list(range(self.min_sequence_len)), # OK
[], # Empty
],
"labels": [ # Add labels to test truncation
[1, 2],
list(range(self.sequence_len + 5)),
list(range(self.sequence_len)),
list(range(self.min_sequence_len)),
[],
],
}
result = handler(sample)
# Expected results after truncation (too short and empty remain unchanged by this function)
expected_input_ids = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
expected_labels = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
self.assertEqual(result["input_ids"], expected_input_ids)
self.assertEqual(result["labels"], expected_labels)
if __name__ == "__main__":
unittest.main()