📝 Add docstrings to 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length

Docstrings generation was requested by @mhenrichsen.

* https://github.com/axolotl-ai-cloud/axolotl/pull/2662#issuecomment-2883401776

The following files were modified:

* `src/axolotl/utils/data/pretraining.py`
* `src/axolotl/utils/data/rl.py`
* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
* `tests/test_data.py`
* `tests/test_trainer_utils.py`
This commit is contained in:
coderabbitai[bot]
2025-05-15 11:02:45 +00:00
committed by GitHub
parent 5d7a61576d
commit e23a5c9fda
6 changed files with 215 additions and 38 deletions

View File

@@ -251,6 +251,22 @@ def encode_packed_pretraining(
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
"""
Encodes and packs input examples into fixed-length batches for pretraining with optional multipack attention.
Wraps and processes input examples into a dataset, applies sequence packing with configurable overflow handling, and batches the data using a multipack sampler. Each batch is collated and features are aggregated into lists keyed by feature name.
Args:
collate_fn: Function to collate individual feature dictionaries into batch tensors.
ds_wrapper: Callable that wraps a Hugging Face Dataset for further processing.
examples: Dictionary of input examples to encode and pack.
max_seq_length: Maximum sequence length for each packed sequence.
batch_size: Number of sequences to pack per batch.
multipack_attn: If True, enables multipack attention and drops attention masks.
Returns:
Dictionary where each key is a feature name and each value is a list of packed feature tensors.
"""
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(

View File

@@ -85,6 +85,25 @@ def drop_long_rl_seq(
sequence_len,
handling="drop", # Use the default handling mode
):
"""
Handles samples exceeding a maximum sequence length for various RL dataset types by either truncating or dropping them.
Depending on the RL type and the `handling` mode, this function either truncates response fields to fit within the specified sequence length or determines whether the sample should be dropped. For DPO, IPO, ORPO, and SIMPO types, both "chosen" and "rejected" responses are considered; for KTO, the "completion" is considered. For GRPO, samples are always retained. If truncation is not possible (e.g., the prompt alone exceeds the limit), the sample is returned unchanged for mapping, or dropped during filtering.
Args:
sample: A dictionary representing a single dataset sample.
rl: The RLType indicating the dataset type.
tokenizer: The tokenizer used to compute token lengths and perform truncation.
sequence_len: The maximum allowed sequence length.
handling: Specifies how to handle overlong sequences ("drop" or "truncate").
Returns:
For "truncate": The modified sample with responses truncated as needed, or the original sample if truncation is not possible.
For "drop": True if the sample fits within the sequence length, otherwise False.
Raises:
ValueError: If required keys are missing for the specified RL type, or if the RL type is unknown.
"""
result = None
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
@@ -210,6 +229,17 @@ def drop_long_rl_seq(
def load_prepare_preference_datasets(cfg):
"""
Loads, preprocesses, and prepares preference datasets for RL training and evaluation.
This function orchestrates the loading, transformation, sequence length handling, optional deduplication, and caching of datasets for Direct Preference Optimization (DPO) and related RL types. It supports configurable handling of overlong sequences (dropping or truncating), applies dataset-specific transformations, and manages train/validation/test splits as needed.
Args:
cfg: Configuration object specifying dataset sources, RL type, tokenizer, sequence length, and processing options.
Returns:
A tuple containing the prepared training and evaluation datasets.
"""
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token

View File

@@ -161,6 +161,18 @@ def deduplicate_and_log_datasets(
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
"""
Processes a dataset to handle sequences exceeding a configured maximum length by either truncating or dropping them.
If the dataset lacks an "input_ids" column, the function returns the dataset unchanged. The handling mode is determined by the configuration parameter "sequence_len_overflow_handling", defaulting to "drop". In "truncate" mode, sequences longer than the maximum length are truncated; in "drop" mode, such sequences are removed from the dataset. The function logs information about sequence lengths and the number of samples affected when applicable.
Args:
dataset: The Huggingface Dataset to process.
cfg: Configuration object specifying sequence length parameters and handling mode.
Returns:
The processed dataset with long sequences either truncated or dropped according to the configuration.
"""
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."

View File

@@ -207,10 +207,18 @@ def add_length(sample):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
Determines whether samples should be kept based on sequence length constraints.
For a single example or a batch, returns True (or a list of booleans) if each sequence's length is within the specified range; otherwise, returns False (or a list with False for out-of-range sequences).
Args:
sample: A dictionary containing "input_ids" as a list of ints or a list of lists of ints.
sequence_len: Maximum allowed sequence length (inclusive).
min_sequence_len: Minimum allowed sequence length (inclusive).
Returns:
True if the single example is within the length range, False otherwise.
For batched input, returns a list of booleans indicating which sequences are within the range.
"""
min_sequence_len = min_sequence_len or 2
@@ -239,17 +247,18 @@ def truncate_or_drop_long_seq(
sample, sequence_len=2048, min_sequence_len=2, handling="drop"
):
"""
Either drop or truncate samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
If handling is "drop":
- Samples that are too short or too long will be dropped
If handling is "truncate":
- Samples that are too short will still be dropped
- Samples that are too long will be truncated to sequence_len
Works for both single-example (list[int]) or batched (list[list[int]]).
Returns either a boolean/list of booleans (for drop mode) or the modified sample (for truncate mode).
Drops or truncates samples based on sequence length constraints.
If handling is "drop", returns a boolean or list of booleans indicating whether each sample's sequence length is within the specified range. If handling is "truncate", returns the sample with sequences longer than sequence_len truncated and sequences shorter than min_sequence_len omitted. Supports both single-example and batched inputs.
Args:
sample: A dictionary containing at least an "input_ids" field, representing either a single sequence or a batch of sequences.
sequence_len: Maximum allowed sequence length.
min_sequence_len: Minimum allowed sequence length.
handling: "drop" to filter out samples outside the range, "truncate" to truncate long sequences.
Returns:
In "drop" mode, a boolean or list of booleans indicating which samples to keep. In "truncate" mode, the modified sample with sequences truncated as needed.
"""
min_sequence_len = min_sequence_len or 2
result = None
@@ -344,6 +353,11 @@ def truncate_or_drop_long_seq(
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
"""
Prepares training and evaluation datasets for sample packing and model-specific requirements.
Removes unnecessary columns based on model type, filters out samples with no trainable tokens, and optionally adds length or position ID columns for sample packing or PoSE techniques. Returns the processed training and evaluation datasets.
"""
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
LOG.info("dropping attention_mask column")
@@ -485,6 +499,21 @@ def process_pretraining_datasets_for_packing(
handling="drop",
):
# Define the function to use for handling sequences based on the mode
"""
Processes a pretraining dataset by truncating or dropping sequences based on length.
Depending on the handling mode, sequences longer than `sequence_len` are either truncated or dropped, and sequences shorter than `min_sequence_len` are dropped. Optionally adds position IDs and removes the attention mask column.
Args:
train_dataset: The dataset to process.
sequence_len: Maximum allowed sequence length.
skip_position_ids: If False, adds position IDs to each sample.
drop_attention_mask: If True, removes the attention mask column.
handling: "drop" to remove long sequences, "truncate" to truncate them.
Returns:
The processed dataset with sequences handled according to the specified mode.
"""
seq_handler_fn = partial(
truncate_or_drop_long_seq,
sequence_len=sequence_len,