Compare commits

...

21 Commits

Author SHA1 Message Date
mhenrhcsen
0f2d196476 Remove deprecated configuration files: deleted config.qmd and finetune copy.yml to streamline project structure and eliminate unused resources. 2025-08-12 21:23:34 +02:00
mhenrhcsen
f1a8474400 Remove transscribe.py file and clean up optimizer.py and rl.py for improved formatting and consistency. 2025-08-12 21:20:48 +02:00
mhenrhcsen
dc5887c652 pre-commit: fix rl.py imports/types; add legacy drop_long_rl_seq wrapper; resolve config schema; run formatting 2025-08-12 21:12:07 +02:00
mhenrhcsen
54b542d312 remove unused files 2025-08-12 21:09:40 +02:00
mhenrhcsen
30a89b07b9 Refactor AxolotlInputConfig: clean up sequence_len and sequence_len_overflow_handling fields, ensuring consistent descriptions and removing conflict markers. 2025-08-12 21:03:28 +02:00
mhenrhcsen
746c03b097 Clean up conflict markers; finalize RL data split implementation; fix config schema conflicts; add truncation+post-filter behavior and alias handling 2025-08-12 20:53:28 +02:00
mhenrhcsen
47b3fe8af3 Resolve merge conflicts: unify pretraining utils imports, add alias handling; fix rl.py per new RL dataset API; resolve config schema conflict and add sequence_len_overflow_handling field 2025-08-12 20:45:26 +02:00
mhenrhcsen
f5a3e3529e RL datasets: warn and drop unsalvageable over-length prompts post-truncate; add post-truncate filter; support alias config key 'excess_token_handling' 2025-08-12 20:37:41 +02:00
mhenrichsen
618b008e36 Merge branch 'main' into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length 2025-05-27 12:31:31 +02:00
mhenrhcsen
5d7a61576d Refactor sequence length overflow handling in pretraining module
- Introduced DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING constant in utils.py.
- Updated encode_packed_pretraining function to use this constant instead of a hardcoded value.
2025-05-15 12:55:09 +02:00
mhenrhcsen
5ecf22b54e Merge branch 'main' of github.com:axolotl-ai-cloud/axolotl into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length 2025-05-14 13:36:43 +02:00
mhenrhcsen
9c5b8da22f fix merge conflicts 2025-05-14 13:33:42 +02:00
mhenrhcsen
fea6649518 increased test coverage 2025-05-13 08:58:34 +02:00
mhenrhcsen
124ad2b968 lint 2025-05-13 08:35:16 +02:00
mhenrhcsen
767c2340f1 docstring for tests 2025-05-12 22:57:43 +02:00
mhenrhcsen
f6623c34cc Linting fix 2025-05-12 22:53:30 +02:00
mhenrhcsen
5dd8f0b2b8 Fixes comments from winglian 2025-05-12 22:43:15 +02:00
mhenrhcsen
be3c6bbd85 fix linting issues 2025-05-12 14:46:57 +02:00
mhenrhcsen
f07db4f853 Refactor truncation logic in drop_long_rl_seq function
- Simplified the truncation process for chosen and rejected responses to ensure they fit within the specified sequence length while preserving the prompt.
- Improved readability by restructuring the code and removing redundant checks.
- Ensured that the function returns the sample correctly after processing, maintaining compatibility with existing handling options.
2025-05-12 14:40:10 +02:00
mhenrhcsen
17a5838d38 lint 2025-05-12 14:36:43 +02:00
mhenrhcsen
9f68918f13 Implement configurable handling of excess tokens in datasets
- Added `excess_token_handling` option to the configuration, allowing users to choose between "drop" and "truncate" for handling tokens exceeding the maximum sequence length.
- Introduced `truncate_or_drop_long_seq` function to manage both single and batched samples based on the selected handling method.
- Updated relevant dataset processing functions to utilize the new handling option, ensuring backward compatibility with existing "drop" behavior.
- Enhanced logging to reflect truncation actions in dataset processing.

This change improves flexibility in managing sequence lengths during training and evaluation.
2025-05-12 14:08:43 +02:00
8 changed files with 724 additions and 31 deletions

View File

@@ -185,12 +185,12 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init

View File

@@ -10,6 +10,7 @@ from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
@@ -259,6 +260,15 @@ def encode_packed_pretraining(
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
# pass through handling mode from config via ds_wrapper function
handling=(
getattr(ds_wrapper, "cfg", {}).get(
"sequence_len_overflow_handling",
getattr(ds_wrapper, "cfg", {}).get(
"excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING
),
)
),
)
sampler = MultipackBatchSampler(

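The nested lookup above resolves the handling mode with a two-level fallback: the new sequence_len_overflow_handling key wins, the legacy alias excess_token_handling is consulted next, and the default applies otherwise. A minimal standalone sketch of that resolution (resolve_overflow_handling is a hypothetical helper name, not part of this diff):

def resolve_overflow_handling(cfg: dict) -> str:
    # Prefer the new key, fall back to the legacy alias, then the default "drop"
    return cfg.get(
        "sequence_len_overflow_handling",
        cfg.get("excess_token_handling", "drop"),
    )

assert resolve_overflow_handling({}) == "drop"
assert resolve_overflow_handling({"excess_token_handling": "truncate"}) == "truncate"
# The new key takes precedence over the legacy alias
assert resolve_overflow_handling(
    {"sequence_len_overflow_handling": "drop", "excess_token_handling": "truncate"}
) == "drop"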
View File

@@ -122,6 +122,14 @@ def _map_dataset(
return dataset
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
    """
    Backward-compatibility wrapper for legacy imports in tests.
    Injects the handling mode into the sample, then delegates to the new
    predicate so that `handling` is not silently dropped.
    """
    sample = {**sample, "sequence_len_overflow_handling": handling}
    result = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
    if isinstance(result, dict):
        result.pop("sequence_len_overflow_handling", None)
    return result
-def _drop_long_sequences(
-    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
-) -> bool:
+def _drop_long_sequences(
+    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
+) -> bool | dict[str, Any]:
@@ -155,11 +163,51 @@ def _drop_long_sequences(
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
-        return (len_prompt + len_chosen) <= sequence_len and (
-            len_prompt + len_rejected
-        ) <= sequence_len
+        # Truncate first; the post-truncate filter drops anything still invalid
+        handling_mode = sample.get("sequence_len_overflow_handling", "drop")
+        if handling_mode == "truncate":
+            # If both sequences fit, return the sample unchanged
+            if (len_prompt + len_chosen) <= sequence_len and (
+                len_prompt + len_rejected
+            ) <= sequence_len:
+                result = sample
+            else:
+                # Maximum response length that fits alongside the prompt
+                max_response_len = sequence_len - len_prompt
+                if max_response_len <= 0:
+                    # The prompt alone exceeds sequence_len; truncating the
+                    # responses cannot fix it. Return the sample unchanged so
+                    # the post-truncate filter can drop it.
+                    LOG.warning(
+                        "Prompt length (%s) exceeds sequence length (%s) for DPO-like sample; it will be dropped by the post-truncate filter",
+                        len_prompt,
+                        sequence_len,
+                    )
+                    result = sample
+                else:
+                    # Truncate the chosen and rejected responses if needed
+                    if len_chosen > max_response_len:
+                        chosen_tokens = tokenizer(chosen, add_special_tokens=False)[
+                            "input_ids"
+                        ][:max_response_len]
+                        sample["chosen"] = tokenizer.decode(
+                            chosen_tokens, skip_special_tokens=True
+                        )
+                    if len_rejected > max_response_len:
+                        rejected_tokens = tokenizer(rejected, add_special_tokens=False)[
+                            "input_ids"
+                        ][:max_response_len]
+                        sample["rejected"] = tokenizer.decode(
+                            rejected_tokens, skip_special_tokens=True
+                        )
+                    result = sample
+        else:  # handling == "drop"
+            result = (len_prompt + len_chosen) <= sequence_len and (
+                len_prompt + len_rejected
+            ) <= sequence_len
-    if rl is RLType.KTO:
+    elif rl == RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
@@ -171,12 +219,86 @@ def _drop_long_sequences(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
-        return (len_prompt + len_completion) <= sequence_len
+        # Truncate first; the post-truncate filter drops anything still invalid
+        handling_mode = sample.get("sequence_len_overflow_handling", "drop")
+        if handling_mode == "truncate":
+            # If the sequence fits, return the sample unchanged
+            if (len_prompt + len_completion) <= sequence_len:
+                result = sample
+            else:
+                # Maximum completion length that fits alongside the prompt
+                max_completion_len = sequence_len - len_prompt
+                if max_completion_len <= 0:
+                    # The prompt alone exceeds sequence_len; truncating the
+                    # completion cannot fix it. Return the sample unchanged so
+                    # the post-truncate filter can drop it.
+                    LOG.warning(
+                        "Prompt length (%s) exceeds sequence length (%s) for KTO sample; it will be dropped by the post-truncate filter",
+                        len_prompt,
+                        sequence_len,
+                    )
+                    result = sample
+                else:
+                    # Truncate the completion if needed
+                    if len_completion > max_completion_len:
+                        completion_tokens = tokenizer(
+                            completion, add_special_tokens=False
+                        )["input_ids"][:max_completion_len]
+                        sample["completion"] = tokenizer.decode(
+                            completion_tokens, skip_special_tokens=True
+                        )
+                    result = sample
+        else:  # handling == "drop"
+            result = (len_prompt + len_completion) <= sequence_len
-    if rl is RLType.GRPO:
-        return True
-    raise ValueError("Unknown RL type")
+    elif rl == RLType.GRPO:
+        # GRPO samples are always kept; in truncate mode return the sample unchanged
+        handling_mode = sample.get("sequence_len_overflow_handling", "drop")
+        result = sample if handling_mode == "truncate" else True
+    else:
+        raise ValueError("Unknown RL type")
+    return result if isinstance(result, dict) else bool(result)
-def load_prepare_preference_datasets(cfg):
def _is_rl_seq_within_sequence_len(sample, rl, tokenizer, sequence_len):
"""
Boolean predicate to check whether a preference-learning sample fits within sequence_len.
Used with dataset.filter() after truncation to drop unsalvageable samples.
"""
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
return False
prompt = sample["prompt"]
chosen = sample["chosen"]
rejected = sample["rejected"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(
tokenizer(rejected, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
if rl == RLType.KTO:
if not (sample.get("prompt") and sample.get("completion")):
return False
prompt = sample["prompt"]
completion = sample["completion"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_completion = len(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
if rl == RLType.GRPO:
# GRPO does not enforce this check here
return True
return False
# Legacy shim preserved for backward compatibility; no-op in new flow
def load_split(dataset_cfgs, _cfg): # noqa: F811
return None
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:

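Taken together, truncate mode is a two-pass flow over a preference dataset: a map pass that truncates responses in place, followed by a filter pass with _is_rl_seq_within_sequence_len that drops samples whose prompt alone already exceeds sequence_len. A hedged sketch of how the passes could be wired (apply_overflow_handling is a hypothetical helper, not part of this diff; the marker key it injects mirrors what the legacy wrapper above does):

def apply_overflow_handling(dataset, rl, tokenizer, sequence_len, handling="drop"):
    if handling == "truncate":
        # Pass 1: tag each row with the mode, then truncate responses in place.
        # The tag becomes an extra column; it could be removed afterwards with
        # dataset.remove_columns if undesired.
        dataset = dataset.map(
            lambda s: _drop_long_sequences(
                {**s, "sequence_len_overflow_handling": "truncate"},
                rl,
                tokenizer,
                sequence_len,
            )
        )
        # Pass 2: drop rows that are still too long (e.g. over-length prompts)
        return dataset.filter(
            lambda s: _is_rl_seq_within_sequence_len(s, rl, tokenizer, sequence_len)
        )
    # Drop mode is a single boolean filter
    return dataset.filter(
        lambda s: _drop_long_sequences(s, rl, tokenizer, sequence_len)
    )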
View File

@@ -15,10 +15,12 @@ from datasets import Dataset, IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers.utils import get_dataset_lengths
-from axolotl.utils.trainer import drop_long_seq
+from axolotl.utils.trainer import truncate_or_drop_long_seq
LOG = get_logger(__name__)
DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING = "drop"
class RetryStrategy(Enum):
"""Enum for retry strategies."""
@@ -168,10 +170,19 @@ def drop_long_seq_in_dataset(
)
return dataset
-    drop_long = functools.partial(
-        drop_long_seq,
+    # Get the handling mode from config, defaulting to "drop" for backward
+    # compatibility; support the legacy alias "excess_token_handling" as well.
+    handling = cfg.get(
+        "sequence_len_overflow_handling",
+        cfg.get("excess_token_handling", DEFAULT_SEQUENCE_LEN_OVERFLOW_HANDLING),
+    )
+    # Build the handler with the configured mode
+    seq_handler = functools.partial(
+        truncate_or_drop_long_seq,
        sequence_len=sequence_len,
        min_sequence_len=cfg.min_sample_len,
+        handling=handling,
    )
with contextlib.suppress(AttributeError):
@@ -190,17 +201,31 @@ def drop_long_seq_in_dataset(
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
if handling == "truncate":
drop_long_kwargs["desc"] = "Truncating Long Sequences"
else: # handling == "drop"
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
-    dataset = dataset.filter(
-        drop_long,
-        batched=True,
-        **filter_map_kwargs,
-        **drop_long_kwargs,
-    )
-    if prior_len:
-        dropped = prior_len - len(dataset)
-        if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from dataset")
+    if handling == "truncate":
+        # Truncate mode rewrites samples, so use map
+        dataset = dataset.map(
+            seq_handler,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
+        LOG.info(f"Truncated long samples in dataset to {sequence_len} tokens")
+    else:  # handling == "drop"
+        # Drop mode returns booleans, so use filter
+        dataset = dataset.filter(
+            seq_handler,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
+    if prior_len:
+        dropped = prior_len - len(dataset)
+        if dropped:
+            LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset
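The dispatch above hinges on the handler's return type: in truncate mode it returns rewritten rows (so it runs through Dataset.map), while in drop mode it returns booleans (so it runs through Dataset.filter). A small usage sketch with a toy dataset (hypothetical values; it assumes truncate_or_drop_long_seq behaves as defined in the trainer module below):

import functools

from datasets import Dataset

from axolotl.utils.trainer import truncate_or_drop_long_seq

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], list(range(100))]})

# Truncate mode rewrites over-length rows, so it goes through map
handler = functools.partial(
    truncate_or_drop_long_seq, sequence_len=8, handling="truncate"
)
ds = ds.map(handler, batched=True)
assert all(len(row) <= 8 for row in ds["input_ids"])

# Drop mode returns booleans, so it goes through filter instead
pred = functools.partial(truncate_or_drop_long_seq, sequence_len=8, handling="drop")
assert len(ds.filter(pred, batched=True)) == 2  # both rows fit after truncation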

View File

@@ -414,6 +414,12 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(
default="drop",
json_schema_extra={
"description": "How to handle sequences that overflow the sequence_len: 'drop' (remove the sample) or 'truncate' (cut off excess tokens)."
},
)
eval_sequence_len: int | None = Field(
default=None,
json_schema_extra={

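Because the field is typed as Literal["drop", "truncate"], any other value is rejected when the config is validated. A minimal sketch of that behavior with a standalone model (illustrative only, assuming pydantic v2 semantics):

from typing import Literal

from pydantic import BaseModel, Field, ValidationError

class OverflowConfig(BaseModel):
    # Mirrors the new field above; "drop" stays the default for backward compatibility
    sequence_len_overflow_handling: Literal["drop", "truncate"] = Field(default="drop")

assert OverflowConfig().sequence_len_overflow_handling == "drop"
assert (
    OverflowConfig(
        sequence_len_overflow_handling="truncate"
    ).sequence_len_overflow_handling
    == "truncate"
)

try:
    OverflowConfig(sequence_len_overflow_handling="pad")  # not a valid literal
    raise AssertionError("expected a ValidationError")
except ValidationError as err:
    assert err.error_count() == 1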
View File

@@ -233,6 +233,114 @@ def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return results
def truncate_or_drop_long_seq(
    sample, sequence_len=2048, min_sequence_len=2, handling="drop"
):
    """
    Either drop or truncate samples whose sequence length is too long
    (> sequence_len) or too short (< min_sequence_len).

    If handling is "drop":
        - Samples that are too short or too long are dropped.
    If handling is "truncate":
        - Samples that are too long are truncated to sequence_len.
        - Samples that are too short (or empty) are returned unchanged, since a
          map cannot drop rows; filtering them out is left to upstream logic.

    Works for both single examples (input_ids: list[int]) and batches
    (input_ids: list[list[int]]). Returns a boolean or list of booleans in
    drop mode, and the (possibly modified) sample in truncate mode.
    """
    min_sequence_len = min_sequence_len or 2

    if handling == "drop":
        # Drop mode is a pure predicate; delegate to the existing helper
        return drop_long_seq(sample, sequence_len, min_sequence_len)

    # Truncate mode from here on
    input_ids = sample["input_ids"]

    # Edge case: empty input_ids, nothing to truncate
    if not input_ids:
        return sample

    # Keys that must stay aligned with input_ids when truncating
    aligned_keys = ("attention_mask", "labels", "position_ids")

    # Single example (input_ids is a list of ints)
    if isinstance(input_ids[0], int):
        if len(input_ids) > sequence_len:
            sample["input_ids"] = input_ids[:sequence_len]
            for key in aligned_keys:
                if key in sample:
                    sample[key] = sample[key][:sequence_len]
            if "length" in sample:
                sample["length"] = sequence_len
        return sample

    # Batched (input_ids is a list of lists)
    for i, seq in enumerate(input_ids):
        if len(seq) > sequence_len:
            input_ids[i] = seq[:sequence_len]
            for key in aligned_keys:
                if key in sample:
                    sample[key][i] = sample[key][i][:sequence_len]
            if "length" in sample:
                sample["length"][i] = sequence_len
    return sample
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_attn_mask = cfg.model_config_type in ["mamba", "gemma3"]
if drop_attn_mask:
@@ -368,15 +476,33 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing(
-    train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False
+    train_dataset,
+    sequence_len,
+    skip_position_ids=True,
+    drop_attention_mask=False,
+    handling="drop",
):
-    drop_long = partial(drop_long_seq, sequence_len=sequence_len)
-    train_dataset = train_dataset.filter(
-        drop_long,
-        desc="Dropping Long Sequences",
-        load_from_cache_file=False,
-    )
+    # Build the sequence handler for the configured mode
+    seq_handler_fn = partial(
+        truncate_or_drop_long_seq,
+        sequence_len=sequence_len,
+        handling=handling,
+    )
+    # Use map for truncate mode and filter for drop mode
+    if handling == "truncate":
+        train_dataset = train_dataset.map(
+            seq_handler_fn,
+            desc="Truncating Long Sequences",
+            load_from_cache_file=False,
+        )
+    else:  # handling == "drop"
+        train_dataset = train_dataset.filter(
+            seq_handler_fn,  # same function; it returns booleans in drop mode
+            desc="Dropping Long Sequences",
+            load_from_cache_file=False,
+        )
if not skip_position_ids:
train_dataset = train_dataset.map(
add_position_ids,

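For a quick sense of the two modes of truncate_or_drop_long_seq, here is a behavioral sketch with toy values (hypothetical inputs; it assumes the semantics of the function as defined above):

from axolotl.utils.trainer import truncate_or_drop_long_seq

batch = {
    "input_ids": [[1, 2, 3, 4, 5], [1, 2]],
    "labels": [[1, 2, 3, 4, 5], [1, 2]],
}

# Drop mode: one boolean per row (the first row is too long for sequence_len=4)
keep = truncate_or_drop_long_seq(
    {"input_ids": [list(row) for row in batch["input_ids"]]},
    sequence_len=4,
    min_sequence_len=2,
    handling="drop",
)
assert keep == [False, True]

# Truncate mode: over-length rows and their aligned fields are cut to 4 tokens
out = truncate_or_drop_long_seq(
    batch, sequence_len=4, min_sequence_len=2, handling="truncate"
)
assert out["input_ids"][0] == [1, 2, 3, 4]
assert out["labels"][0] == [1, 2, 3, 4]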
View File

@@ -3,10 +3,12 @@ test module for the axolotl.utils.data module
"""
import unittest
from unittest.mock import MagicMock
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_pretraining, md5
from axolotl.utils.data.rl import drop_long_rl_seq
from tests.hf_offline_utils import enable_hf_offline
@@ -64,5 +66,254 @@ class TestEncodePretraining(unittest.TestCase):
)
class TestDropLongRLSeq(unittest.TestCase):
"""
Tests for the drop_long_rl_seq function.
"""
def setUp(self):
# Mock tokenizer that returns length based on input string length
self.tokenizer = MagicMock()
def side_effect_func(
text, add_special_tokens=False
): # pylint: disable=unused-argument
return {"input_ids": list(range(len(text)))}
self.tokenizer.side_effect = side_effect_func
self.tokenizer.decode = lambda tokens, skip_special_tokens: "".join(
["x"] * len(tokens)
) # pylint: disable=unused-argument
self.sequence_len = 20
def test_dpo_drop_mode_valid(self):
"""Test DPO drop mode with a valid sample."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_dpo_drop_mode_invalid_chosen(self):
"""Test DPO drop mode with chosen too long."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 16,
"rejected": "r" * 6,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_drop_mode_invalid_rejected(self):
"""Test DPO drop mode with rejected too long."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 16,
} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_dpo_truncate_mode_no_truncation_needed(self):
"""Test DPO truncate mode when no truncation is needed."""
sample = {
"prompt": "p" * 5,
"chosen": "c" * 7,
"rejected": "r" * 6,
} # 5+7=12 <= 20, 5+6=11 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(
result, original_sample
) # Should return the original sample unchanged
def test_dpo_truncate_mode_prompt_too_long(self):
"""Test DPO truncate mode when the prompt itself is too long."""
sample = {"prompt": "p" * 25, "chosen": "c" * 7, "rejected": "r" * 6}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
# Even though truncation isn't possible, the function should return the original sample
# for the map operation, assuming downstream filtering will catch it.
self.assertEqual(result, original_sample)
def test_dpo_truncate_mode_chosen_truncated(self):
"""Test DPO truncate mode when only 'chosen' needs truncation."""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 20 - 5 = 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 18,
"rejected": "r" * 10,
} # 5+18=23 > 20, 5+10=15 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["chosen"], "x" * max_resp_len
) # Check decoded truncated value
self.assertEqual(len(result["rejected"]), 10) # Unchanged
def test_dpo_truncate_mode_rejected_truncated(self):
"""Test DPO truncate mode when only 'rejected' needs truncation."""
prompt_len = 5
max_resp_len = self.sequence_len - prompt_len # 15
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 10,
"rejected": "r" * 18,
} # 5+10=15 <= 20, 5+18=23 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), 10) # Unchanged
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 15
self.assertEqual(
result["rejected"], "x" * max_resp_len
) # Check decoded truncated value
def test_dpo_truncate_mode_both_truncated(self):
"""Test DPO truncate mode when both 'chosen' and 'rejected' need truncation."""
prompt_len = 8
max_resp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 15,
"rejected": "r" * 14,
} # 8+15=23 > 20, 8+14=22 > 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 12
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), max_resp_len) # Truncated to 12
self.assertEqual(result["rejected"], "x" * max_resp_len)
def test_dpo_truncate_mode_no_truncation_needed_but_long(self):
"""Test DPO truncate mode where one response exceeds the budget and the other fits."""
# The combined prompt + chosen length fails the initial check, so the
# truncate path runs: 'chosen' (11 > max_resp_len of 10) is truncated,
# while 'rejected' (9 <= 10) already fits within max_response_len and is
# left unchanged.
prompt_len = 10
max_resp_len = self.sequence_len - prompt_len # 10
sample = {
"prompt": "p" * prompt_len,
"chosen": "c" * 11,
"rejected": "r" * 9,
} # 10+11=21 > 20, 10+9=19 <= 20
result = drop_long_rl_seq(
sample, "dpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["chosen"]), max_resp_len) # Truncated to 10
self.assertEqual(result["chosen"], "x" * max_resp_len)
self.assertEqual(len(result["rejected"]), 9) # Unchanged, as 9 <= 10
# Add similar tests for KTO if needed, checking prompt + completion length
def test_kto_drop_mode_valid(self):
"""Test KTO drop mode with a valid sample."""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_kto_drop_mode_invalid(self):
"""Test KTO drop mode with an invalid sample."""
sample = {"prompt": "p" * 5, "completion": "c" * 16} # 5+16=21 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertFalse(result)
def test_kto_truncate_mode_no_truncation_needed(self):
"""Test KTO truncate mode when no truncation is needed."""
sample = {"prompt": "p" * 5, "completion": "c" * 14} # 5+14=19 <= 20
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample)
def test_kto_truncate_mode_prompt_too_long(self):
"""Test KTO truncate mode when the prompt itself is too long."""
sample = {"prompt": "p" * 25, "completion": "c" * 7}
original_sample = sample.copy()
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, original_sample) # Returns original sample
def test_kto_truncate_mode_completion_truncated(self):
"""Test KTO truncate mode when completion needs truncation."""
prompt_len = 8
max_comp_len = self.sequence_len - prompt_len # 20 - 8 = 12
sample = {"prompt": "p" * prompt_len, "completion": "c" * 15} # 8+15=23 > 20
result = drop_long_rl_seq(
sample, "kto", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(len(result["prompt"]), prompt_len)
self.assertEqual(len(result["completion"]), max_comp_len) # Truncated to 12
self.assertEqual(result["completion"], "x" * max_comp_len)
def test_missing_keys_dpo(self):
"""Test ValueError raised if keys missing for DPO."""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt, chosen and rejected keys are required"
):
drop_long_rl_seq(sample, "dpo", self.tokenizer, self.sequence_len)
def test_missing_keys_kto(self):
"""Test ValueError raised if keys missing for KTO."""
sample = {"prompt": "p"}
with self.assertRaisesRegex(
ValueError, "Prompt and completion keys are required"
):
drop_long_rl_seq(sample, "kto", self.tokenizer, self.sequence_len)
def test_unknown_rl_type(self):
"""Test ValueError raised for unknown RL type."""
sample = {}
with self.assertRaisesRegex(ValueError, "Unknown RL type"):
drop_long_rl_seq(sample, "xyz", self.tokenizer, self.sequence_len)
# GRPO test - current implementation always passes
def test_grpo_drop(self):
"""Test GRPO drop mode (currently always True)."""
sample = {}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="drop"
)
self.assertTrue(result)
def test_grpo_truncate(self):
"""Test GRPO truncate mode (currently returns original sample)."""
sample = {"a": 1}
result = drop_long_rl_seq(
sample, "grpo", self.tokenizer, self.sequence_len, handling="truncate"
)
self.assertEqual(result, sample)
if __name__ == "__main__":
unittest.main()

tests/test_trainer_utils.py (new file, 153 lines)
View File

@@ -0,0 +1,153 @@
"""Module containing tests for trainer utility functions."""
import unittest
from functools import partial
from axolotl.utils.trainer import truncate_or_drop_long_seq
# Test cases for truncate_or_drop_long_seq
class TestTruncateOrDropLongSeq(unittest.TestCase):
"""
Test suite for truncate_or_drop_long_seq function.
"""
def setUp(self):
# Example sequence length settings
self.sequence_len = 10
self.min_sequence_len = 3
def test_drop_mode_single(self):
"""Test drop mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
# Too short
sample_short = {"input_ids": [1, 2]}
self.assertFalse(handler(sample_short))
# Too long
sample_long = {"input_ids": list(range(self.sequence_len + 1))}
self.assertFalse(handler(sample_long))
# Just right
sample_ok = {"input_ids": list(range(self.min_sequence_len))}
self.assertTrue(handler(sample_ok))
# Empty
sample_empty = {"input_ids": []}
self.assertFalse(handler(sample_empty))
def test_truncate_mode_single(self):
"""Test truncate mode with single examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
# Too short: the function does not drop short samples in truncate mode;
# it returns them unchanged and leaves dropping to upstream filtering.
sample_short = {"input_ids": [1, 2], "labels": [1, 2]}
result_short = handler(sample_short)
self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged
# Too long
original_long = list(range(self.sequence_len + 5))
sample_long = {"input_ids": list(original_long), "labels": list(original_long)}
result_long = handler(sample_long)
self.assertEqual(len(result_long["input_ids"]), self.sequence_len)
self.assertEqual(result_long["input_ids"], list(range(self.sequence_len)))
self.assertEqual(len(result_long["labels"]), self.sequence_len)
self.assertEqual(result_long["labels"], list(range(self.sequence_len)))
# Just right
sample_ok = {
"input_ids": list(range(self.min_sequence_len)),
"labels": list(range(self.min_sequence_len)),
}
result_ok = handler(sample_ok)
self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len)
self.assertEqual(result_ok, sample_ok) # Should be unchanged
# Empty
sample_empty = {"input_ids": [], "labels": []}
result_empty = handler(sample_empty)
self.assertEqual(result_empty, sample_empty) # Unchanged
def test_drop_mode_batched(self):
"""Test drop mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="drop",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 1)), # Too long
list(range(self.sequence_len)), # OK (len = 10)
list(range(self.min_sequence_len)), # OK (len = 3)
[], # Empty
]
}
expected = [False, False, True, True, False]
self.assertEqual(handler(sample), expected)
def test_truncate_mode_batched(self):
"""Test truncate mode with batched examples."""
handler = partial(
truncate_or_drop_long_seq,
sequence_len=self.sequence_len,
min_sequence_len=self.min_sequence_len,
handling="truncate",
)
sample = {
"input_ids": [
[1, 2], # Too short
list(range(self.sequence_len + 5)), # Too long
list(range(self.sequence_len)), # OK
list(range(self.min_sequence_len)), # OK
[], # Empty
],
"labels": [ # Add labels to test truncation
[1, 2],
list(range(self.sequence_len + 5)),
list(range(self.sequence_len)),
list(range(self.min_sequence_len)),
[],
],
}
result = handler(sample)
# Expected results after truncation (too short and empty remain unchanged by this function)
expected_input_ids = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
expected_labels = [
[1, 2], # Unchanged (too short)
list(range(self.sequence_len)), # Truncated
list(range(self.sequence_len)), # Unchanged (OK)
list(range(self.min_sequence_len)), # Unchanged (OK)
[], # Unchanged (Empty)
]
self.assertEqual(result["input_ids"], expected_input_ids)
self.assertEqual(result["labels"], expected_labels)
if __name__ == "__main__":
unittest.main()