📝 Add docstrings to 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length

Docstrings generation was requested by @mhenrichsen.

* https://github.com/axolotl-ai-cloud/axolotl/pull/2662#issuecomment-2883401776

The following files were modified:

* `src/axolotl/utils/data/pretraining.py`
* `src/axolotl/utils/data/rl.py`
* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
* `tests/test_data.py`
* `tests/test_trainer_utils.py`
This commit is contained in:
coderabbitai[bot]
2025-05-15 11:02:45 +00:00
committed by GitHub
parent 5d7a61576d
commit e23a5c9fda
6 changed files with 215 additions and 38 deletions

View File

@@ -251,6 +251,22 @@ def encode_packed_pretraining(
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
"""
Encodes and packs input examples into fixed-length batches for pretraining with optional multipack attention.
Wraps and processes input examples into a dataset, applies sequence packing with configurable overflow handling, and batches the data using a multipack sampler. Each batch is collated and features are aggregated into lists keyed by feature name.
Args:
collate_fn: Function to collate individual feature dictionaries into batch tensors.
ds_wrapper: Callable that wraps a Hugging Face Dataset for further processing.
examples: Dictionary of input examples to encode and pack.
max_seq_length: Maximum sequence length for each packed sequence.
batch_size: Number of sequences to pack per batch.
multipack_attn: If True, enables multipack attention and drops attention masks.
Returns:
Dictionary where each key is a feature name and each value is a list of packed feature tensors.
"""
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(

View File

@@ -85,6 +85,25 @@ def drop_long_rl_seq(
sequence_len,
handling="drop", # Use the default handling mode
):
"""
Handles samples exceeding a maximum sequence length for various RL dataset types by either truncating or dropping them.
Depending on the RL type and the `handling` mode, this function either truncates response fields to fit within the specified sequence length or determines whether the sample should be dropped. For DPO, IPO, ORPO, and SIMPO types, both "chosen" and "rejected" responses are considered; for KTO, the "completion" is considered. For GRPO, samples are always retained. If truncation is not possible (e.g., the prompt alone exceeds the limit), the sample is returned unchanged for mapping, or dropped during filtering.
Args:
sample: A dictionary representing a single dataset sample.
rl: The RLType indicating the dataset type.
tokenizer: The tokenizer used to compute token lengths and perform truncation.
sequence_len: The maximum allowed sequence length.
handling: Specifies how to handle overlong sequences ("drop" or "truncate").
Returns:
For "truncate": The modified sample with responses truncated as needed, or the original sample if truncation is not possible.
For "drop": True if the sample fits within the sequence length, otherwise False.
Raises:
ValueError: If required keys are missing for the specified RL type, or if the RL type is unknown.
"""
result = None
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
@@ -210,6 +229,17 @@ def drop_long_rl_seq(
def load_prepare_preference_datasets(cfg):
"""
Loads, preprocesses, and prepares preference datasets for RL training and evaluation.
This function orchestrates the loading, transformation, sequence length handling, optional deduplication, and caching of datasets for Direct Preference Optimization (DPO) and related RL types. It supports configurable handling of overlong sequences (dropping or truncating), applies dataset-specific transformations, and manages train/validation/test splits as needed.
Args:
cfg: Configuration object specifying dataset sources, RL type, tokenizer, sequence length, and processing options.
Returns:
A tuple containing the prepared training and evaluation datasets.
"""
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token

View File

@@ -161,6 +161,18 @@ def deduplicate_and_log_datasets(
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
"""
Processes a dataset to handle sequences exceeding a configured maximum length by either truncating or dropping them.
If the dataset lacks an "input_ids" column, the function returns the dataset unchanged. The handling mode is determined by the configuration parameter "sequence_len_overflow_handling", defaulting to "drop". In "truncate" mode, sequences longer than the maximum length are truncated; in "drop" mode, such sequences are removed from the dataset. The function logs information about sequence lengths and the number of samples affected when applicable.
Args:
dataset: The Huggingface Dataset to process.
cfg: Configuration object specifying sequence length parameters and handling mode.
Returns:
The processed dataset with long sequences either truncated or dropped according to the configuration.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."

View File

@@ -207,10 +207,18 @@ def add_length(sample):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
Determines whether samples should be kept based on sequence length constraints.
For a single example or a batch, returns True (or a list of booleans) if each sequence's length is within the specified range; otherwise, returns False (or a list with False for out-of-range sequences).
Args:
sample: A dictionary containing "input_ids" as a list of ints or a list of lists of ints.
sequence_len: Maximum allowed sequence length (inclusive).
min_sequence_len: Minimum allowed sequence length (inclusive).
Returns:
True if the single example is within the length range, False otherwise.
For batched input, returns a list of booleans indicating which sequences are within the range.
"""
min_sequence_len = min_sequence_len or 2
@@ -239,17 +247,18 @@ def truncate_or_drop_long_seq(
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
):
"""
Either drop or truncate samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
If handling is "drop":
- Samples that are too short or too long will be dropped
If handling is "truncate":
- Samples that are too short will still be dropped
- Samples that are too long will be truncated to sequence_len
Works for both single-example (list[int]) or batched (list[list[int]]).
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
Drops or truncates samples based on sequence length constraints.
If handling is "drop", returns a boolean or list of booleans indicating whether each sample's sequence length is within the specified range. If handling is "truncate", returns the sample with sequences longer than sequence_len truncated and sequences shorter than min_sequence_len omitted. Supports both single-example and batched inputs.
Args:
sample: A dictionary containing at least an "input_ids" field, representing either a single sequence or a batch of sequences.
sequence_len: Maximum allowed sequence length.
min_sequence_len: Minimum allowed sequence length.
handling: "drop" to filter out samples outside the range, "truncate" to truncate long sequences.
Returns:
In "drop" mode, a boolean or list of booleans indicating which samples to keep. In "truncate" mode, the modified sample with sequences truncated as needed.
"""
min_sequence_len = min_sequence_len or 2
result = None
@@ -344,6 +353,11 @@ def truncate_or_drop_long_seq(
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
"""
Prepares training and evaluation datasets for sample packing and model-specific requirements.
Removes unnecessary columns based on model type, filters out samples with no trainable tokens, and optionally adds length or position ID columns for sample packing or PoSE techniques. Returns the processed training and evaluation datasets.
"""
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
LOG.info("dropping attention_mask column")
@@ -485,6 +499,21 @@ def process_pretraining_datasets_for_packing(
handling="drop",
):
# Define the function to use for handling sequences based on the mode
"""
Processes a pretraining dataset by truncating or dropping sequences based on length.
Depending on the handling mode, sequences longer than `sequence_len` are either truncated or dropped, and sequences shorter than `min_sequence_len` are dropped. Optionally adds position IDs and removes the attention mask column.
Args:
train_dataset: The dataset to process.
sequence_len: Maximum allowed sequence length.
skip_position_ids: If False, adds position IDs to each sample.
drop_attention_mask: If True, removes the attention mask column.
handling: "drop" to remove long sequences, "truncate" to truncate them.
Returns:
The processed dataset with sequences handled according to the specified mode.
"""
seq_handler_fn = partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,

View File

@@ -60,6 +60,9 @@ class TestEncodePretraining(unittest.TestCase):
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
def test_md5(self):
"""
Tests that the md5 function returns the correct hash for a given string and encoding.
"""
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
self.assertEqual(
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
@@ -73,11 +76,26 @@ class TestDropLongRLSeq(unittest.TestCase):
def setUp(self):
# Mock tokenizer that returns length based on input string length
"""
Sets up a mock tokenizer and sequence length for RL sequence length tests.
The mock tokenizer simulates tokenization by returning input IDs equal to the input string's length and decodes tokens as repeated "x" characters. The sequence length limit is set to 20.
"""
self.tokenizer = MagicMock()
def side_effect_func(
text, add_special_tokens=False
): # pylint: disable=unused-argument
"""
Simulates tokenization by returning input IDs as a sequence of integers equal to the input text length.
Args:
text: The input string to tokenize.
add_special_tokens: Ignored parameter included for interface compatibility.
Returns:
A dictionary with 'input_ids' as a list of integers from 0 to len(text) - 1.
"""
return {"input_ids": list(range(len(text)))}
self.tokenizer.side_effect = side_effect_func
@@ -88,7 +106,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.sequence_len = 20
def test_dpo_drop_mode_valid(self):
"""Test DPO drop mode with a valid sample."""
"""
Tests that drop_long_rl_seq returns True in drop mode for a DPO sample within the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
@@ -100,7 +120,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertTrue(result)
def test_dpo_drop_mode_invalid_chosen(self):
"""Test DPO drop mode with chosen too long."""
"""
Tests that in DPO drop mode, a sample is rejected when the prompt and chosen lengths exceed the sequence limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 16,
@@ -112,7 +134,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertFalse(result)
def test_dpo_drop_mode_invalid_rejected(self):
"""Test DPO drop mode with rejected too long."""
"""
Tests that in DPO drop mode, a sample is rejected when the prompt plus rejected response exceeds the sequence length limit.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
@@ -124,7 +148,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertFalse(result)
def test_dpo_truncate_mode_no_truncation_needed(self):
"""Test DPO truncate mode when no truncation is needed."""
"""
Verifies that in DPO truncate mode, samples within the sequence length limit are returned unchanged.
"""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
@@ -139,7 +165,10 @@ class TestDropLongRLSeq(unittest.TestCase):
) # Should return the original sample unchanged
def test_dpo_truncate_mode_prompt_too_long(self):
"""Test DPO truncate mode when the prompt itself is too long."""
"""
Tests that in DPO truncate mode, if the prompt exceeds the sequence length limit,
the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
original_sample = sample.copy()
result = drop_long_rl_seq(
@@ -150,7 +179,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result, original_sample)
def test_dpo_truncate_mode_chosen_truncated(self):
"""Test DPO truncate mode when only 'chosen' needs truncation."""
"""
Tests that in DPO truncate mode, only the 'chosen' field is truncated when it exceeds the allowed sequence length, while 'prompt' and 'rejected' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
sample = {
@@ -169,7 +200,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(len(result["rejected"]), 10) # Unchanged
def test_dpo_truncate_mode_rejected_truncated(self):
"""Test DPO truncate mode when only 'rejected' needs truncation."""
"""
Tests that in DPO truncate mode, only the 'rejected' field is truncated when it exceeds the sequence length limit, while 'prompt' and 'chosen' remain unchanged.
"""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 15
sample = {
@@ -188,7 +221,11 @@ class TestDropLongRLSeq(unittest.TestCase):
) # Check decoded truncated value
def test_dpo_truncate_mode_both_truncated(self):
"""Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
"""
Tests that in DPO truncate mode, both 'chosen' and 'rejected' fields are truncated when their combined lengths with the prompt exceed the sequence limit.
Verifies that both fields are truncated to fit within the allowed response length and replaced with decoded placeholder content.
"""
prompt_len = 8
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {
@@ -206,7 +243,11 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result["rejected"], "x" * max_resp_len)
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
"""Test DPO truncate mode where individual parts fit but combined don't, but no truncation happens."""
"""
Tests DPO truncate mode where only the overlong response is truncated.
Verifies that when the prompt plus one response exceeds the sequence length, only the response exceeding the maximum allowed length is truncated, while the other remains unchanged.
"""
# This tests the case where len(chosen) <= max_resp_len and len(rejected) <= max_resp_len
# but the initial check failed because e.g. prompt + chosen > sequence_len
# The current logic *will* truncate if len(chosen) > max_resp_len.
@@ -230,7 +271,9 @@ class TestDropLongRLSeq(unittest.TestCase):
# Add similar tests for KTO if needed, checking prompt + completion length
def test_kto_drop_mode_valid(self):
"""Test KTO drop mode with a valid sample."""
"""
Tests that drop_long_rl_seq returns True for a KTO sample within the sequence length limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
@@ -238,7 +281,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertTrue(result)
def test_kto_drop_mode_invalid(self):
"""Test KTO drop mode with an invalid sample."""
"""
Tests that drop_long_rl_seq returns False when a KTO sample exceeds the sequence length limit in drop mode.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
@@ -246,7 +291,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertFalse(result)
def test_kto_truncate_mode_no_truncation_needed(self):
"""Test KTO truncate mode when no truncation is needed."""
"""
Tests that KTO truncate mode returns the original sample unchanged when the combined prompt and completion length does not exceed the sequence limit.
"""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
@@ -255,7 +302,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result, original_sample)
def test_kto_truncate_mode_prompt_too_long(self):
"""Test KTO truncate mode when the prompt itself is too long."""
"""
Tests that in KTO truncate mode, if the prompt exceeds the sequence length limit, the original sample is returned unchanged.
"""
sample = {"prompt": "p" * 25, "completion": "c" * 7}
original_sample = sample.copy()
result = drop_long_rl_seq(
@@ -264,7 +313,11 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result, original_sample) # Returns original sample
def test_kto_truncate_mode_completion_truncated(self):
"""Test KTO truncate mode when completion needs truncation."""
"""
Tests that in KTO truncate mode, the completion is truncated when the combined prompt and completion exceed the sequence length limit.
Verifies that the prompt remains unchanged and the completion is truncated to fit within the allowed length, with the truncated completion replaced by decoded "x" characters.
"""
prompt_len = 8
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
@@ -276,7 +329,11 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertEqual(result["completion"], "x" * max_comp_len)
def test_missing_keys_dpo(self):
"""Test ValueError raised if keys missing for DPO."""
"""
Tests that a ValueError is raised when required keys are missing for DPO samples.
Verifies that the function raises an error if the sample does not contain 'chosen' and 'rejected' keys.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt, chosen and rejected keys are required"
@@ -284,7 +341,12 @@ class TestDropLongRLSeq(unittest.TestCase):
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
def test_missing_keys_kto(self):
"""Test ValueError raised if keys missing for KTO."""
"""
Tests that a ValueError is raised when required keys are missing for RL type "kto".
Verifies that calling drop_long_rl_seq with a sample missing the "completion" key raises
a ValueError with the expected error message.
"""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt and completion keys are required"
@@ -292,14 +354,18 @@ class TestDropLongRLSeq(unittest.TestCase):
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
def test_unknown_rl_type(self):
"""Test ValueError raised for unknown RL type."""
"""
Tests that a ValueError is raised when an unknown RL type is provided to drop_long_rl_seq.
"""
sample = {}
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
# GRPO test - current implementation always passes
def test_grpo_drop(self):
"""Test GRPO drop mode (currently always True)."""
"""
Tests that drop_long_rl_seq in GRPO drop mode always returns True, regardless of input.
"""
sample = {}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
@@ -307,7 +373,9 @@ class TestDropLongRLSeq(unittest.TestCase):
self.assertTrue(result)
def test_grpo_truncate(self):
"""Test GRPO truncate mode (currently returns original sample)."""
"""
Tests that in truncate mode for RL type "grpo", the original sample is returned unchanged.
"""
sample = {"a": 1}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"

View File

@@ -14,11 +14,19 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
def setUp(self):
# Example sequence length settings
"""
Sets up default sequence length parameters for the test cases.
"""
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""Test drop mode with single examples."""
"""
Verifies that 'drop' mode correctly filters single sequence examples based on length.
Tests that sequences shorter than the minimum, longer than the maximum, or empty are dropped,
while sequences within the valid length range are kept.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
@@ -43,7 +51,11 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""Test truncate mode with single examples."""
"""
Tests that 'truncate_or_drop_long_seq' correctly truncates or preserves single examples in "truncate" mode.
Verifies that sequences longer than the maximum length are truncated, while sequences that are too short, empty, or within the valid range remain unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
@@ -83,7 +95,11 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""Test drop mode with batched examples."""
"""
Tests that the "drop" handling mode correctly filters batched input sequences based on length constraints.
Verifies that sequences shorter than the minimum length, longer than the maximum length, or empty are dropped (returns False), while sequences within the valid range are kept (returns True).
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
@@ -103,7 +119,13 @@ class TestTruncateOrDropLongSeq(unittest.TestCase):
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""Test truncate mode with batched examples."""
"""
Tests that batched examples are correctly truncated in "truncate" mode.
Verifies that sequences in both "input_ids" and "labels" longer than the maximum
allowed length are truncated, while sequences that are too short or empty remain
unchanged.
"""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,