From 63a58cfec116f7518b08bca8debcbe8ef1370031 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=82=86=E3=82=8A?=
Date: Mon, 13 Apr 2026 08:51:10 +0800
Subject: [PATCH] feat: support excess_length_strategy for RL trainers (#3578)
 [skip ci]

* feat: support excess_length_strategy for RL trainers

Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.

- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
  within sequence_len while preserving the full prompt, then decodes
  back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len

Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami
Co-authored-by: Wing Lian
---
 src/axolotl/utils/data/rl.py | 199 ++++++++++++++++++++++--
 tests/utils/data/test_rl.py  | 292 +++++++++++++++++++++++++++++++++++
 2 files changed, 475 insertions(+), 16 deletions(-)
 create mode 100644 tests/utils/data/test_rl.py

diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py
index ef91e1124..c1d775cb4 100644
--- a/src/axolotl/utils/data/rl.py
+++ b/src/axolotl/utils/data/rl.py
@@ -180,6 +180,119 @@ def _drop_long_sequences(
     raise ValueError("Unknown RL type")
 
 
+def _raise_on_long_sequences(
+    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
+) -> bool:
+    """Check sequence length and raise ValueError if exceeded.
+
+    Used as a filter function for ``excess_length_strategy: raise``.
+
+    Args:
+        sample: Dataset sample to check.
+        rl: Reinforcement learning type.
+        tokenizer: Tokenizer for length calculation.
+        sequence_len: Maximum allowed sequence length.
+
+    Returns:
+        Always True (raises before returning False).
+
+    Raises:
+        ValueError: If any sample exceeds the configured sequence length.
+    """
+    is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
+    if not is_valid:
+        raise ValueError(
+            f"Sample exceeds configured sequence_len ({sequence_len}). "
+            "Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
+            "to handle long sequences automatically."
+        )
+    return True
+
+
+def _truncate_long_sequences_rl(
+    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
+) -> dict[str, Any]:
+    """Truncate RL samples that exceed maximum sequence length.
+
+    For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
+    responses to fit within ``sequence_len`` when combined with the prompt.
+    For KTO, truncates the completion similarly.
+    GRPO/GDPO/EBFT samples are returned unchanged.
+
+    Samples whose prompt alone fills or exceeds ``sequence_len`` cannot be
+    meaningfully truncated and are returned unchanged. The caller should
+    follow up with a drop filter to remove them.
+
+    Args:
+        sample: Dataset sample to potentially truncate.
+        rl: Reinforcement learning type.
+        tokenizer: Tokenizer for encoding/decoding.
+        sequence_len: Maximum allowed sequence length.
+
+    Returns:
+        The sample with text fields truncated to fit within sequence_len.
+    """
+    # Fast path: if sample already fits, return unchanged (avoids decode overhead)
+    if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
+        return sample
+
+    if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
+        if not (
+            sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
+        ):
+            raise ValueError(
+                "Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
+            )
+
+        prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
+        chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
+        rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
+            "input_ids"
+        ]
+
+        max_response_len = sequence_len - len(prompt_ids)
+        if max_response_len <= 0:
+            # Prompt alone fills or exceeds the limit; cannot meaningfully truncate.
+            # Returned unchanged — the follow-up drop filter will remove it.
+            return sample
+
+        updates: dict[str, Any] = {}
+        if len(chosen_ids) > max_response_len:
+            updates["chosen"] = tokenizer.decode(
+                chosen_ids[:max_response_len], skip_special_tokens=False
+            )
+        if len(rejected_ids) > max_response_len:
+            updates["rejected"] = tokenizer.decode(
+                rejected_ids[:max_response_len], skip_special_tokens=False
+            )
+        if updates:
+            sample = {**sample, **updates}
+
+    elif rl is RLType.KTO:
+        if not (sample.get("prompt") and sample.get("completion")):
+            raise ValueError("Prompt and completion keys are required for KTO datasets")
+
+        prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
+        completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
+            "input_ids"
+        ]
+
+        max_completion_len = sequence_len - len(prompt_ids)
+        if max_completion_len <= 0:
+            return sample
+
+        if len(completion_ids) > max_completion_len:
+            sample = {
+                **sample,
+                "completion": tokenizer.decode(
+                    completion_ids[:max_completion_len], skip_special_tokens=False
+                ),
+            }
+
+    # GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
+    return sample
+
+
 def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
     """Load and process dataset split for RL training.
 
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
                 split_datasets[i] = dataset
 
         if not cfg.skip_prepare_dataset:
-            drop_long = partial(
-                _drop_long_sequences,
-                rl=cfg.rl,
-                tokenizer=tokenizer,
-                sequence_len=cfg.sequence_len,
-            )
+            excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
 
-            prior_len = len(split_datasets[i])
-            split_datasets[i] = split_datasets[i].filter(
-                drop_long,
-                num_proc=cfg.dataset_num_proc,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Dropping Long Sequences",
-            )
-            dropped = prior_len - len(split_datasets[i])
-            if dropped:
-                LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
+            if excess_length_strategy == "truncate":
+                truncate_fn = partial(
+                    _truncate_long_sequences_rl,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+                prior_len = len(split_datasets[i])
+                split_datasets[i] = split_datasets[i].map(
+                    truncate_fn,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Truncating Long Sequences",
+                )
+
+                # Drop samples that could not be truncated (e.g. prompt
+                # alone exceeds sequence_len)
+                drop_long = partial(
+                    _drop_long_sequences,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+                split_datasets[i] = split_datasets[i].filter(
+                    drop_long,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Dropping Un-truncatable Sequences",
+                )
+                dropped = prior_len - len(split_datasets[i])
+                if dropped:
+                    LOG.warning(
+                        f"Dropped {dropped} samples from dataset index {i} "
+                        f"that could not be truncated to fit sequence_len "
+                        f"(e.g. prompt alone exceeds limit)"
+                    )
+            elif excess_length_strategy == "raise":
+                raise_fn = partial(
+                    _raise_on_long_sequences,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+                split_datasets[i] = split_datasets[i].filter(
+                    raise_fn,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Checking Sequence Lengths",
+                )
+            else:  # "drop" (default)
+                drop_long = partial(
+                    _drop_long_sequences,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+
+                prior_len = len(split_datasets[i])
+                split_datasets[i] = split_datasets[i].filter(
+                    drop_long,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Dropping Long Sequences",
+                )
+                dropped = prior_len - len(split_datasets[i])
+                if dropped:
+                    LOG.warning(
+                        f"Dropped {dropped} long samples from dataset index {i}"
+                    )
 
     # Merge datasets
     dataset = merge_datasets(split_datasets, cfg)
diff --git a/tests/utils/data/test_rl.py b/tests/utils/data/test_rl.py
new file mode 100644
index 000000000..44c6a010d
--- /dev/null
+++ b/tests/utils/data/test_rl.py
@@ -0,0 +1,292 @@
+"""
+Unit tests for RL data utility functions (excess_length_strategy support).
+"""
+
+import unittest
+
+from axolotl.utils.data.rl import (
+    _drop_long_sequences,
+    _raise_on_long_sequences,
+    _truncate_long_sequences_rl,
+)
+from axolotl.utils.schemas.enums import RLType
+
+
+class _FakeTokenizer:
+    """Simple whitespace tokenizer for testing length calculations."""
+
+    def __call__(self, text, add_special_tokens=True):  # noqa: ARG002
+        tokens = text.split()
+        return {"input_ids": list(range(len(tokens)))}
+
+    def decode(self, token_ids, skip_special_tokens=True):  # noqa: ARG002
+        # Each token id maps to a placeholder word; length is what matters.
+        return " ".join(f"w{i}" for i in range(len(token_ids)))
+
+
+def _make_dpo_sample(prompt_len: int, chosen_len: int, rejected_len: int):
+    """Create a DPO sample with specified word counts."""
+    return {
+        "prompt": " ".join(f"p{i}" for i in range(prompt_len)),
+        "chosen": " ".join(f"c{i}" for i in range(chosen_len)),
+        "rejected": " ".join(f"r{i}" for i in range(rejected_len)),
+    }
+
+
+def _make_kto_sample(prompt_len: int, completion_len: int):
+    """Create a KTO sample with specified word counts."""
+    return {
+        "prompt": " ".join(f"p{i}" for i in range(prompt_len)),
+        "completion": " ".join(f"c{i}" for i in range(completion_len)),
+    }
+
+
+class TestDropLongSequences(unittest.TestCase):
+    """Tests for the existing _drop_long_sequences filter function."""
+
+    def setUp(self):
+        self.tokenizer = _FakeTokenizer()
+
+    def test_dpo_keeps_short_samples(self):
+        sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
+        result = _drop_long_sequences(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertTrue(result)
+
+    def test_dpo_drops_long_chosen(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
+        result = _drop_long_sequences(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertFalse(result)
+
+    def test_dpo_drops_long_rejected(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=2, rejected_len=10)
+        result = _drop_long_sequences(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertFalse(result)
+
+    def test_kto_keeps_short_samples(self):
+        sample = _make_kto_sample(prompt_len=3, completion_len=2)
+        result = _drop_long_sequences(
+            sample, RLType.KTO, self.tokenizer, sequence_len=10
+        )
+        self.assertTrue(result)
+
+    def test_kto_drops_long_completion(self):
+        sample = _make_kto_sample(prompt_len=5, completion_len=10)
+        result = _drop_long_sequences(
+            sample, RLType.KTO, self.tokenizer, sequence_len=10
+        )
+        self.assertFalse(result)
+
+    def test_grpo_always_keeps(self):
+        sample = {"prompt": "a " * 100}
+        result = _drop_long_sequences(
+            sample, RLType.GRPO, self.tokenizer, sequence_len=5
+        )
+        self.assertTrue(result)
+
+    def test_dpo_missing_keys_raises(self):
+        with self.assertRaises(ValueError):
+            _drop_long_sequences({"prompt": "hi"}, RLType.DPO, self.tokenizer, 10)
+
+    def test_kto_missing_keys_raises(self):
+        with self.assertRaises(ValueError):
+            _drop_long_sequences({"prompt": "hi"}, RLType.KTO, self.tokenizer, 10)
+
+    def test_ipo_uses_dpo_logic(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
+        result = _drop_long_sequences(
+            sample, RLType.IPO, self.tokenizer, sequence_len=10
+        )
+        self.assertFalse(result)
+
+    def test_orpo_uses_dpo_logic(self):
+        sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
+        result = _drop_long_sequences(
+            sample, RLType.ORPO, self.tokenizer, sequence_len=10
+        )
+        self.assertTrue(result)
+
+    def test_boundary_length_kept(self):
+        """Samples exactly at sequence_len should be kept."""
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
+        result = _drop_long_sequences(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertTrue(result)
+
+
+class TestRaiseOnLongSequences(unittest.TestCase):
+    """Tests for _raise_on_long_sequences (excess_length_strategy='raise')."""
+
+    def setUp(self):
+        self.tokenizer = _FakeTokenizer()
+
+    def test_short_sample_passes(self):
+        sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
+        result = _raise_on_long_sequences(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertTrue(result)
+
+    def test_long_sample_raises_valueerror(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
+        with self.assertRaisesRegex(ValueError, "excess_length_strategy"):
+            _raise_on_long_sequences(
+                sample, RLType.DPO, self.tokenizer, sequence_len=10
+            )
+
+    def test_kto_long_raises(self):
+        sample = _make_kto_sample(prompt_len=5, completion_len=10)
+        with self.assertRaises(ValueError):
+            _raise_on_long_sequences(
+                sample, RLType.KTO, self.tokenizer, sequence_len=10
+            )
+
+    def test_grpo_never_raises(self):
+        sample = {"prompt": "a " * 100}
+        result = _raise_on_long_sequences(
+            sample, RLType.GRPO, self.tokenizer, sequence_len=5
+        )
+        self.assertTrue(result)
+
+
+class TestTruncateLongSequencesRL(unittest.TestCase):
+    """Tests for _truncate_long_sequences_rl (excess_length_strategy='truncate')."""
+
+    def setUp(self):
+        self.tokenizer = _FakeTokenizer()
+
+    def test_dpo_short_sample_unchanged(self):
+        sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertEqual(result["chosen"], sample["chosen"])
+        self.assertEqual(result["rejected"], sample["rejected"])
+
+    def test_dpo_truncates_chosen(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        # max_response_len = 10 - 5 = 5, chosen had 10 words -> truncated to 5
+        chosen_tokens = self.tokenizer(result["chosen"], add_special_tokens=False)[
+            "input_ids"
+        ]
+        self.assertEqual(len(chosen_tokens), 5)
+
+    def test_dpo_truncates_rejected(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=3, rejected_len=10)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        rejected_tokens = self.tokenizer(result["rejected"], add_special_tokens=False)[
+            "input_ids"
+        ]
+        self.assertEqual(len(rejected_tokens), 5)
+
+    def test_dpo_truncates_both(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        chosen_len = len(
+            self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
+        )
+        rejected_len = len(
+            self.tokenizer(result["rejected"], add_special_tokens=False)["input_ids"]
+        )
+        self.assertEqual(chosen_len, 5)
+        self.assertEqual(rejected_len, 5)
+
+    def test_dpo_prompt_unchanged(self):
+        """Prompt text should never be modified."""
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertEqual(result["prompt"], sample["prompt"])
+
+    def test_dpo_prompt_exceeds_limit_returns_unchanged(self):
+        """When prompt alone exceeds sequence_len, sample is returned as-is."""
+        sample = _make_dpo_sample(prompt_len=15, chosen_len=3, rejected_len=3)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertEqual(result, sample)
+
+    def test_kto_truncates_completion(self):
+        sample = _make_kto_sample(prompt_len=5, completion_len=10)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.KTO, self.tokenizer, sequence_len=10
+        )
+        completion_len = len(
+            self.tokenizer(result["completion"], add_special_tokens=False)["input_ids"]
+        )
+        self.assertEqual(completion_len, 5)
+
+    def test_kto_short_sample_unchanged(self):
+        sample = _make_kto_sample(prompt_len=3, completion_len=2)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.KTO, self.tokenizer, sequence_len=10
+        )
+        self.assertEqual(result["completion"], sample["completion"])
+
+    def test_kto_prompt_exceeds_limit_returns_unchanged(self):
+        sample = _make_kto_sample(prompt_len=15, completion_len=3)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.KTO, self.tokenizer, sequence_len=10
+        )
+        self.assertEqual(result, sample)
+
+    def test_grpo_unchanged(self):
+        sample = {"prompt": "a " * 100}
+        result = _truncate_long_sequences_rl(
+            sample, RLType.GRPO, self.tokenizer, sequence_len=5
+        )
+        self.assertEqual(result, sample)
+
+    def test_ipo_uses_dpo_logic(self):
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.IPO, self.tokenizer, sequence_len=10
+        )
+        chosen_len = len(
+            self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
+        )
+        self.assertEqual(chosen_len, 5)
+
+    def test_does_not_mutate_original(self):
+        """Verify immutability — original sample dict is not modified."""
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
+        original_chosen = sample["chosen"]
+        original_rejected = sample["rejected"]
+        _truncate_long_sequences_rl(sample, RLType.DPO, self.tokenizer, sequence_len=10)
+        self.assertEqual(sample["chosen"], original_chosen)
+        self.assertEqual(sample["rejected"], original_rejected)
+
+    def test_dpo_missing_keys_raises(self):
+        with self.assertRaises(ValueError):
+            _truncate_long_sequences_rl(
+                {"prompt": "hi"}, RLType.DPO, self.tokenizer, 10
+            )
+
+    def test_kto_missing_keys_raises(self):
+        with self.assertRaises(ValueError):
+            _truncate_long_sequences_rl(
+                {"prompt": "hi"}, RLType.KTO, self.tokenizer, 10
+            )
+
+    def test_boundary_no_truncation_needed(self):
+        """Samples exactly at sequence_len should not be modified."""
+        sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
+        result = _truncate_long_sequences_rl(
+            sample, RLType.DPO, self.tokenizer, sequence_len=10
+        )
+        self.assertEqual(result["chosen"], sample["chosen"])
+        self.assertEqual(result["rejected"], sample["rejected"])