* feat: support excess_length_strategy for RL trainers

  Previously, RL data loading always dropped sequences exceeding sequence_len. This adds support for the existing `excess_length_strategy` config option (`drop`, `truncate`, `raise`) in RL training pipelines, matching the behavior already available for SFT.

  - `drop` (default): unchanged behavior; filters out long samples.
  - `truncate`: tokenizes text components, truncates responses to fit within sequence_len while preserving the full prompt, then decodes back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
  - `raise`: raises ValueError if any sample exceeds sequence_len.

  Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
293 lines
11 KiB
Python
293 lines
11 KiB
Python
"""
|
|
Unit tests for RL data utility functions (excess_length_strategy support).
|
|
"""
|
|
|
|
import unittest
|
|
|
|
from axolotl.utils.data.rl import (
|
|
_drop_long_sequences,
|
|
_raise_on_long_sequences,
|
|
_truncate_long_sequences_rl,
|
|
)
|
|
from axolotl.utils.schemas.enums import RLType
|
|
|
|
|
|
class _FakeTokenizer:
|
|
"""Simple whitespace tokenizer for testing length calculations."""
|
|
|
|
def __call__(self, text, add_special_tokens=True): # noqa: ARG002
|
|
tokens = text.split()
|
|
return {"input_ids": list(range(len(tokens)))}
|
|
|
|
def decode(self, token_ids, skip_special_tokens=True): # noqa: ARG002
|
|
# Each token id maps to a placeholder word; length is what matters.
|
|
return " ".join(f"w{i}" for i in range(len(token_ids)))
|
|
|
|
|
|
def _make_dpo_sample(prompt_len: int, chosen_len: int, rejected_len: int):
|
|
"""Create a DPO sample with specified word counts."""
|
|
return {
|
|
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
|
|
"chosen": " ".join(f"c{i}" for i in range(chosen_len)),
|
|
"rejected": " ".join(f"r{i}" for i in range(rejected_len)),
|
|
}
|
|
|
|
|
|
def _make_kto_sample(prompt_len: int, completion_len: int):
|
|
"""Create a KTO sample with specified word counts."""
|
|
return {
|
|
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
|
|
"completion": " ".join(f"c{i}" for i in range(completion_len)),
|
|
}
|
|
|
|
|
|
class TestDropLongSequences(unittest.TestCase):
    """Checks for the ``_drop_long_sequences`` filter predicate.

    All cases use ``_FakeTokenizer``, so a field's token count equals its
    number of whitespace-separated words.
    """

    def setUp(self):
        self.tokenizer = _FakeTokenizer()

    def _is_kept(self, sample, rl_type, sequence_len=10):
        """Run the filter with the shared tokenizer and return its verdict."""
        return _drop_long_sequences(
            sample, rl_type, self.tokenizer, sequence_len=sequence_len
        )

    def test_dpo_keeps_short_samples(self):
        short = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
        self.assertTrue(self._is_kept(short, RLType.DPO))

    def test_dpo_drops_long_chosen(self):
        long_chosen = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
        self.assertFalse(self._is_kept(long_chosen, RLType.DPO))

    def test_dpo_drops_long_rejected(self):
        long_rejected = _make_dpo_sample(prompt_len=5, chosen_len=2, rejected_len=10)
        self.assertFalse(self._is_kept(long_rejected, RLType.DPO))

    def test_kto_keeps_short_samples(self):
        short = _make_kto_sample(prompt_len=3, completion_len=2)
        self.assertTrue(self._is_kept(short, RLType.KTO))

    def test_kto_drops_long_completion(self):
        long_completion = _make_kto_sample(prompt_len=5, completion_len=10)
        self.assertFalse(self._is_kept(long_completion, RLType.KTO))

    def test_grpo_always_keeps(self):
        """A GRPO sample is expected to pass even with a 100-word prompt and limit 5."""
        oversized = {"prompt": "a " * 100}
        self.assertTrue(self._is_kept(oversized, RLType.GRPO, sequence_len=5))

    def test_dpo_missing_keys_raises(self):
        """A DPO sample carrying only ``prompt`` is expected to raise ValueError."""
        prompt_only = {"prompt": "hi"}
        with self.assertRaises(ValueError):
            self._is_kept(prompt_only, RLType.DPO)

    def test_kto_missing_keys_raises(self):
        """A KTO sample carrying only ``prompt`` is expected to raise ValueError."""
        prompt_only = {"prompt": "hi"}
        with self.assertRaises(ValueError):
            self._is_kept(prompt_only, RLType.KTO)

    def test_ipo_uses_dpo_logic(self):
        long_chosen = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
        self.assertFalse(self._is_kept(long_chosen, RLType.IPO))

    def test_orpo_uses_dpo_logic(self):
        short = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
        self.assertTrue(self._is_kept(short, RLType.ORPO))

    def test_boundary_length_kept(self):
        """Samples exactly at sequence_len should be kept."""
        # 5 prompt words + 5 response words on each side, against a limit of 10.
        at_limit = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
        self.assertTrue(self._is_kept(at_limit, RLType.DPO))
|
|
|
|
|
|
class TestRaiseOnLongSequences(unittest.TestCase):
    """Tests for _raise_on_long_sequences (excess_length_strategy='raise').

    All cases use ``_FakeTokenizer``, so a field's token count equals its
    number of whitespace-separated words.
    """

    def setUp(self):
        self.tokenizer = _FakeTokenizer()

    def test_short_sample_passes(self):
        """A DPO sample within the limit returns a truthy value instead of raising."""
        sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
        result = _raise_on_long_sequences(
            sample, RLType.DPO, self.tokenizer, sequence_len=10
        )
        self.assertTrue(result)

    def test_long_sample_raises_valueerror(self):
        """An over-long DPO sample raises ValueError naming the strategy option."""
        sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
        # ``assertRaises(..., msg=...)`` only customises the text shown when no
        # exception is raised; it never inspects the exception itself.
        # ``assertRaisesRegex`` actually checks the error message.
        # NOTE(review): assumes the raised message mentions
        # "excess_length_strategy", as the original ``msg=`` argument implied;
        # confirm against the implementation.
        with self.assertRaisesRegex(ValueError, "excess_length_strategy"):
            _raise_on_long_sequences(
                sample, RLType.DPO, self.tokenizer, sequence_len=10
            )

    def test_kto_long_raises(self):
        """An over-long KTO sample raises ValueError."""
        sample = _make_kto_sample(prompt_len=5, completion_len=10)
        with self.assertRaises(ValueError):
            _raise_on_long_sequences(
                sample, RLType.KTO, self.tokenizer, sequence_len=10
            )

    def test_grpo_never_raises(self):
        """A GRPO sample is expected to pass even with a 100-word prompt and limit 5."""
        sample = {"prompt": "a " * 100}
        result = _raise_on_long_sequences(
            sample, RLType.GRPO, self.tokenizer, sequence_len=5
        )
        self.assertTrue(result)
|
|
|
|
|
|
class TestTruncateLongSequencesRL(unittest.TestCase):
    """Checks for ``_truncate_long_sequences_rl`` (excess_length_strategy='truncate').

    All cases use ``_FakeTokenizer``, so a field's token count equals its
    number of whitespace-separated words.
    """

    def setUp(self):
        self.tokenizer = _FakeTokenizer()

    def _truncate(self, sample, rl_type, sequence_len=10):
        """Apply the truncation function with the shared tokenizer."""
        return _truncate_long_sequences_rl(
            sample, rl_type, self.tokenizer, sequence_len=sequence_len
        )

    def _token_count(self, text):
        """Return how many tokens the fake tokenizer produces for *text*."""
        encoded = self.tokenizer(text, add_special_tokens=False)
        return len(encoded["input_ids"])

    def test_dpo_short_sample_unchanged(self):
        original = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
        truncated = self._truncate(original, RLType.DPO)
        self.assertEqual(truncated["chosen"], original["chosen"])
        self.assertEqual(truncated["rejected"], original["rejected"])

    def test_dpo_truncates_chosen(self):
        original = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
        truncated = self._truncate(original, RLType.DPO)
        # Response budget is 10 - 5 = 5, so the 10-word chosen text shrinks to 5.
        self.assertEqual(self._token_count(truncated["chosen"]), 5)

    def test_dpo_truncates_rejected(self):
        original = _make_dpo_sample(prompt_len=5, chosen_len=3, rejected_len=10)
        truncated = self._truncate(original, RLType.DPO)
        self.assertEqual(self._token_count(truncated["rejected"]), 5)

    def test_dpo_truncates_both(self):
        original = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
        truncated = self._truncate(original, RLType.DPO)
        chosen_tokens = self._token_count(truncated["chosen"])
        rejected_tokens = self._token_count(truncated["rejected"])
        self.assertEqual(chosen_tokens, 5)
        self.assertEqual(rejected_tokens, 5)

    def test_dpo_prompt_unchanged(self):
        """Prompt text should never be modified."""
        original = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
        truncated = self._truncate(original, RLType.DPO)
        self.assertEqual(truncated["prompt"], original["prompt"])

    def test_dpo_prompt_exceeds_limit_returns_unchanged(self):
        """When prompt alone exceeds sequence_len, sample is returned as-is."""
        original = _make_dpo_sample(prompt_len=15, chosen_len=3, rejected_len=3)
        self.assertEqual(self._truncate(original, RLType.DPO), original)

    def test_kto_truncates_completion(self):
        original = _make_kto_sample(prompt_len=5, completion_len=10)
        truncated = self._truncate(original, RLType.KTO)
        self.assertEqual(self._token_count(truncated["completion"]), 5)

    def test_kto_short_sample_unchanged(self):
        original = _make_kto_sample(prompt_len=3, completion_len=2)
        truncated = self._truncate(original, RLType.KTO)
        self.assertEqual(truncated["completion"], original["completion"])

    def test_kto_prompt_exceeds_limit_returns_unchanged(self):
        """A KTO sample whose prompt alone exceeds the limit comes back equal to the input."""
        original = _make_kto_sample(prompt_len=15, completion_len=3)
        self.assertEqual(self._truncate(original, RLType.KTO), original)

    def test_grpo_unchanged(self):
        """A GRPO sample comes back equal to the input even when far over the limit."""
        original = {"prompt": "a " * 100}
        self.assertEqual(
            self._truncate(original, RLType.GRPO, sequence_len=5), original
        )

    def test_ipo_uses_dpo_logic(self):
        original = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
        truncated = self._truncate(original, RLType.IPO)
        self.assertEqual(self._token_count(truncated["chosen"]), 5)

    def test_does_not_mutate_original(self):
        """Verify immutability — original sample dict is not modified."""
        original = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
        chosen_before = original["chosen"]
        rejected_before = original["rejected"]
        self._truncate(original, RLType.DPO)
        self.assertEqual(original["chosen"], chosen_before)
        self.assertEqual(original["rejected"], rejected_before)

    def test_dpo_missing_keys_raises(self):
        """A DPO sample carrying only ``prompt`` is expected to raise ValueError."""
        prompt_only = {"prompt": "hi"}
        with self.assertRaises(ValueError):
            self._truncate(prompt_only, RLType.DPO)

    def test_kto_missing_keys_raises(self):
        """A KTO sample carrying only ``prompt`` is expected to raise ValueError."""
        prompt_only = {"prompt": "hi"}
        with self.assertRaises(ValueError):
            self._truncate(prompt_only, RLType.KTO)

    def test_boundary_no_truncation_needed(self):
        """Samples exactly at sequence_len should not be modified."""
        original = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
        truncated = self._truncate(original, RLType.DPO)
        self.assertEqual(truncated["chosen"], original["chosen"])
        self.assertEqual(truncated["rejected"], original["rejected"])
|