feat: support excess_length_strategy for RL trainers (#3578) [skip ci]
* feat: support excess_length_strategy for RL trainers Previously, RL data loading always dropped sequences exceeding sequence_len. This adds support for the existing `excess_length_strategy` config option (`drop`, `truncate`, `raise`) in RL training pipelines, matching the behavior already available for SFT. - `drop` (default): unchanged behavior, filters out long samples - `truncate`: tokenizes text components, truncates responses to fit within sequence_len while preserving the full prompt, then decodes back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets. - `raise`: raises ValueError if any sample exceeds sequence_len Closes #3547 * improve RL truncation strategy robustness and performance --------- Co-authored-by: yurekami <yurekami@users.noreply.github.com> Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -180,6 +180,119 @@ def _drop_long_sequences(
|
|||||||
raise ValueError("Unknown RL type")
|
raise ValueError("Unknown RL type")
|
||||||
|
|
||||||
|
|
||||||
|
def _raise_on_long_sequences(
|
||||||
|
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||||
|
) -> bool:
|
||||||
|
"""Check sequence length and raise ValueError if exceeded.
|
||||||
|
|
||||||
|
Used as a filter function for ``excess_length_strategy: raise``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: Dataset sample to check.
|
||||||
|
rl: Reinforcement learning type.
|
||||||
|
tokenizer: Tokenizer for length calculation.
|
||||||
|
sequence_len: Maximum allowed sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Always True (raises before returning False).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If any sample exceeds the configured sequence length.
|
||||||
|
"""
|
||||||
|
is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
|
||||||
|
if not is_valid:
|
||||||
|
raise ValueError(
|
||||||
|
f"Sample exceeds configured sequence_len ({sequence_len}). "
|
||||||
|
"Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
|
||||||
|
"to handle long sequences automatically."
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_long_sequences_rl(
|
||||||
|
sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Truncate RL samples that exceed maximum sequence length.
|
||||||
|
|
||||||
|
For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
|
||||||
|
responses to fit within ``sequence_len`` when combined with the prompt.
|
||||||
|
For KTO, truncates the completion similarly.
|
||||||
|
GRPO/GDPO/EBFT samples are returned unchanged.
|
||||||
|
|
||||||
|
Samples where the prompt alone exceeds ``sequence_len`` cannot be
|
||||||
|
meaningfully truncated and are returned unchanged. The caller should
|
||||||
|
follow up with a drop filter to remove them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample: Dataset sample to potentially truncate.
|
||||||
|
rl: Reinforcement learning type.
|
||||||
|
tokenizer: Tokenizer for encoding/decoding.
|
||||||
|
sequence_len: Maximum allowed sequence length.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The sample with text fields truncated to fit within sequence_len.
|
||||||
|
"""
|
||||||
|
# Fast path: if sample already fits, return unchanged (avoids decode overhead)
|
||||||
|
if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
|
||||||
|
return sample
|
||||||
|
|
||||||
|
if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
|
||||||
|
if not (
|
||||||
|
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
|
||||||
|
chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
|
||||||
|
rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
|
||||||
|
max_response_len = sequence_len - len(prompt_ids)
|
||||||
|
if max_response_len <= 0:
|
||||||
|
# Prompt alone exceeds limit; cannot meaningfully truncate.
|
||||||
|
# Returned unchanged — the follow-up drop filter will remove it.
|
||||||
|
return sample
|
||||||
|
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if len(chosen_ids) > max_response_len:
|
||||||
|
updates["chosen"] = tokenizer.decode(
|
||||||
|
chosen_ids[:max_response_len], skip_special_tokens=False
|
||||||
|
)
|
||||||
|
if len(rejected_ids) > max_response_len:
|
||||||
|
updates["rejected"] = tokenizer.decode(
|
||||||
|
rejected_ids[:max_response_len], skip_special_tokens=False
|
||||||
|
)
|
||||||
|
if updates:
|
||||||
|
sample = {**sample, **updates}
|
||||||
|
|
||||||
|
elif rl is RLType.KTO:
|
||||||
|
if not (sample.get("prompt") and sample.get("completion")):
|
||||||
|
raise ValueError("Prompt and completion keys are required for KTO datasets")
|
||||||
|
|
||||||
|
prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
|
||||||
|
completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
|
||||||
|
max_completion_len = sequence_len - len(prompt_ids)
|
||||||
|
if max_completion_len <= 0:
|
||||||
|
return sample
|
||||||
|
|
||||||
|
if len(completion_ids) > max_completion_len:
|
||||||
|
sample = {
|
||||||
|
**sample,
|
||||||
|
"completion": tokenizer.decode(
|
||||||
|
completion_ids[:max_completion_len], skip_special_tokens=False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||||
"""Load and process dataset split for RL training.
|
"""Load and process dataset split for RL training.
|
||||||
|
|
||||||
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
|||||||
split_datasets[i] = dataset
|
split_datasets[i] = dataset
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
drop_long = partial(
|
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
|
||||||
_drop_long_sequences,
|
|
||||||
rl=cfg.rl,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
if excess_length_strategy == "truncate":
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
truncate_fn = partial(
|
||||||
drop_long,
|
_truncate_long_sequences_rl,
|
||||||
num_proc=cfg.dataset_num_proc,
|
rl=cfg.rl,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
tokenizer=tokenizer,
|
||||||
desc="Dropping Long Sequences",
|
sequence_len=cfg.sequence_len,
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(split_datasets[i])
|
prior_len = len(split_datasets[i])
|
||||||
if dropped:
|
split_datasets[i] = split_datasets[i].map(
|
||||||
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
|
truncate_fn,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Truncating Long Sequences",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Drop samples that could not be truncated (e.g. prompt
|
||||||
|
# alone exceeds sequence_len)
|
||||||
|
drop_long = partial(
|
||||||
|
_drop_long_sequences,
|
||||||
|
rl=cfg.rl,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Dropping Un-truncatable Sequences",
|
||||||
|
)
|
||||||
|
dropped = prior_len - len(split_datasets[i])
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} samples from dataset index {i} "
|
||||||
|
f"that could not be truncated to fit sequence_len "
|
||||||
|
f"(prompt alone exceeds limit)"
|
||||||
|
)
|
||||||
|
elif excess_length_strategy == "raise":
|
||||||
|
raise_fn = partial(
|
||||||
|
_raise_on_long_sequences,
|
||||||
|
rl=cfg.rl,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
|
raise_fn,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Checking Sequence Lengths",
|
||||||
|
)
|
||||||
|
else: # "drop" (default)
|
||||||
|
drop_long = partial(
|
||||||
|
_drop_long_sequences,
|
||||||
|
rl=cfg.rl,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
sequence_len=cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
prior_len = len(split_datasets[i])
|
||||||
|
split_datasets[i] = split_datasets[i].filter(
|
||||||
|
drop_long,
|
||||||
|
num_proc=cfg.dataset_num_proc,
|
||||||
|
load_from_cache_file=not cfg.is_preprocess,
|
||||||
|
desc="Dropping Long Sequences",
|
||||||
|
)
|
||||||
|
dropped = prior_len - len(split_datasets[i])
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} long samples from dataset index {i}"
|
||||||
|
)
|
||||||
|
|
||||||
# Merge datasets
|
# Merge datasets
|
||||||
dataset = merge_datasets(split_datasets, cfg)
|
dataset = merge_datasets(split_datasets, cfg)
|
||||||
|
|||||||
292
tests/utils/data/test_rl.py
Normal file
292
tests/utils/data/test_rl.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for RL data utility functions (excess_length_strategy support).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from axolotl.utils.data.rl import (
|
||||||
|
_drop_long_sequences,
|
||||||
|
_raise_on_long_sequences,
|
||||||
|
_truncate_long_sequences_rl,
|
||||||
|
)
|
||||||
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeTokenizer:
|
||||||
|
"""Simple whitespace tokenizer for testing length calculations."""
|
||||||
|
|
||||||
|
def __call__(self, text, add_special_tokens=True): # noqa: ARG002
|
||||||
|
tokens = text.split()
|
||||||
|
return {"input_ids": list(range(len(tokens)))}
|
||||||
|
|
||||||
|
def decode(self, token_ids, skip_special_tokens=True): # noqa: ARG002
|
||||||
|
# Each token id maps to a placeholder word; length is what matters.
|
||||||
|
return " ".join(f"w{i}" for i in range(len(token_ids)))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dpo_sample(prompt_len: int, chosen_len: int, rejected_len: int):
|
||||||
|
"""Create a DPO sample with specified word counts."""
|
||||||
|
return {
|
||||||
|
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
|
||||||
|
"chosen": " ".join(f"c{i}" for i in range(chosen_len)),
|
||||||
|
"rejected": " ".join(f"r{i}" for i in range(rejected_len)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_kto_sample(prompt_len: int, completion_len: int):
|
||||||
|
"""Create a KTO sample with specified word counts."""
|
||||||
|
return {
|
||||||
|
"prompt": " ".join(f"p{i}" for i in range(prompt_len)),
|
||||||
|
"completion": " ".join(f"c{i}" for i in range(completion_len)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestDropLongSequences(unittest.TestCase):
|
||||||
|
"""Tests for the existing _drop_long_sequences filter function."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = _FakeTokenizer()
|
||||||
|
|
||||||
|
def test_dpo_keeps_short_samples(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_dpo_drops_long_chosen(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_dpo_drops_long_rejected(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=2, rejected_len=10)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_kto_keeps_short_samples(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=3, completion_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_kto_drops_long_completion(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=5, completion_len=10)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_grpo_always_keeps(self):
|
||||||
|
sample = {"prompt": "a " * 100}
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.GRPO, self.tokenizer, sequence_len=5
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_dpo_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_drop_long_sequences({"prompt": "hi"}, RLType.DPO, self.tokenizer, 10)
|
||||||
|
|
||||||
|
def test_kto_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_drop_long_sequences({"prompt": "hi"}, RLType.KTO, self.tokenizer, 10)
|
||||||
|
|
||||||
|
def test_ipo_uses_dpo_logic(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.IPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertFalse(result)
|
||||||
|
|
||||||
|
def test_orpo_uses_dpo_logic(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.ORPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_boundary_length_kept(self):
|
||||||
|
"""Samples exactly at sequence_len should be kept."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
|
||||||
|
result = _drop_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRaiseOnLongSequences(unittest.TestCase):
|
||||||
|
"""Tests for _raise_on_long_sequences (excess_length_strategy='raise')."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = _FakeTokenizer()
|
||||||
|
|
||||||
|
def test_short_sample_passes(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _raise_on_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
def test_long_sample_raises_valueerror(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=2)
|
||||||
|
with self.assertRaises(ValueError, msg="excess_length_strategy"):
|
||||||
|
_raise_on_long_sequences(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_kto_long_raises(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=5, completion_len=10)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_raise_on_long_sequences(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_grpo_never_raises(self):
|
||||||
|
sample = {"prompt": "a " * 100}
|
||||||
|
result = _raise_on_long_sequences(
|
||||||
|
sample, RLType.GRPO, self.tokenizer, sequence_len=5
|
||||||
|
)
|
||||||
|
self.assertTrue(result)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncateLongSequencesRL(unittest.TestCase):
|
||||||
|
"""Tests for _truncate_long_sequences_rl (excess_length_strategy='truncate')."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = _FakeTokenizer()
|
||||||
|
|
||||||
|
def test_dpo_short_sample_unchanged(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=3, chosen_len=2, rejected_len=2)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["chosen"], sample["chosen"])
|
||||||
|
self.assertEqual(result["rejected"], sample["rejected"])
|
||||||
|
|
||||||
|
def test_dpo_truncates_chosen(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
# max_response_len = 10 - 5 = 5, chosen had 10 words -> truncated to 5
|
||||||
|
chosen_tokens = self.tokenizer(result["chosen"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
self.assertEqual(len(chosen_tokens), 5)
|
||||||
|
|
||||||
|
def test_dpo_truncates_rejected(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=3, rejected_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
rejected_tokens = self.tokenizer(result["rejected"], add_special_tokens=False)[
|
||||||
|
"input_ids"
|
||||||
|
]
|
||||||
|
self.assertEqual(len(rejected_tokens), 5)
|
||||||
|
|
||||||
|
def test_dpo_truncates_both(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
chosen_len = len(
|
||||||
|
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
rejected_len = len(
|
||||||
|
self.tokenizer(result["rejected"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
self.assertEqual(chosen_len, 5)
|
||||||
|
self.assertEqual(rejected_len, 5)
|
||||||
|
|
||||||
|
def test_dpo_prompt_unchanged(self):
|
||||||
|
"""Prompt text should never be modified."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["prompt"], sample["prompt"])
|
||||||
|
|
||||||
|
def test_dpo_prompt_exceeds_limit_returns_unchanged(self):
|
||||||
|
"""When prompt alone exceeds sequence_len, sample is returned as-is."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=15, chosen_len=3, rejected_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result, sample)
|
||||||
|
|
||||||
|
def test_kto_truncates_completion(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=5, completion_len=10)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
completion_len = len(
|
||||||
|
self.tokenizer(result["completion"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
self.assertEqual(completion_len, 5)
|
||||||
|
|
||||||
|
def test_kto_short_sample_unchanged(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=3, completion_len=2)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["completion"], sample["completion"])
|
||||||
|
|
||||||
|
def test_kto_prompt_exceeds_limit_returns_unchanged(self):
|
||||||
|
sample = _make_kto_sample(prompt_len=15, completion_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.KTO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result, sample)
|
||||||
|
|
||||||
|
def test_grpo_unchanged(self):
|
||||||
|
sample = {"prompt": "a " * 100}
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.GRPO, self.tokenizer, sequence_len=5
|
||||||
|
)
|
||||||
|
self.assertEqual(result, sample)
|
||||||
|
|
||||||
|
def test_ipo_uses_dpo_logic(self):
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=3)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.IPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
chosen_len = len(
|
||||||
|
self.tokenizer(result["chosen"], add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
|
self.assertEqual(chosen_len, 5)
|
||||||
|
|
||||||
|
def test_does_not_mutate_original(self):
|
||||||
|
"""Verify immutability — original sample dict is not modified."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=10, rejected_len=10)
|
||||||
|
original_chosen = sample["chosen"]
|
||||||
|
original_rejected = sample["rejected"]
|
||||||
|
_truncate_long_sequences_rl(sample, RLType.DPO, self.tokenizer, sequence_len=10)
|
||||||
|
self.assertEqual(sample["chosen"], original_chosen)
|
||||||
|
self.assertEqual(sample["rejected"], original_rejected)
|
||||||
|
|
||||||
|
def test_dpo_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_truncate_long_sequences_rl(
|
||||||
|
{"prompt": "hi"}, RLType.DPO, self.tokenizer, 10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_kto_missing_keys_raises(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_truncate_long_sequences_rl(
|
||||||
|
{"prompt": "hi"}, RLType.KTO, self.tokenizer, 10
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_boundary_no_truncation_needed(self):
|
||||||
|
"""Samples exactly at sequence_len should not be modified."""
|
||||||
|
sample = _make_dpo_sample(prompt_len=5, chosen_len=5, rejected_len=5)
|
||||||
|
result = _truncate_long_sequences_rl(
|
||||||
|
sample, RLType.DPO, self.tokenizer, sequence_len=10
|
||||||
|
)
|
||||||
|
self.assertEqual(result["chosen"], sample["chosen"])
|
||||||
|
self.assertEqual(result["rejected"], sample["rejected"])
|
||||||
Reference in New Issue
Block a user