feat: support excess_length_strategy for RL trainers (#3578) [skip ci]

* feat: support excess_length_strategy for RL trainers

Previously, RL data loading always dropped sequences exceeding
sequence_len. This adds support for the existing `excess_length_strategy`
config option (`drop`, `truncate`, `raise`) in RL training pipelines,
matching the behavior already available for SFT.

- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit
  within sequence_len while preserving the full prompt, then decodes
  back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len

Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
ゆり
2026-04-13 08:51:10 +08:00
committed by GitHub
parent 3985ec2f67
commit 63a58cfec1
2 changed files with 475 additions and 16 deletions

View File

@@ -180,6 +180,119 @@ def _drop_long_sequences(
raise ValueError("Unknown RL type")
def _raise_on_long_sequences(
    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> bool:
    """Validate that a sample fits within the configured sequence length.

    Filter callback backing ``excess_length_strategy: raise``: rather than
    silently discarding over-length samples, it aborts dataset preparation.

    Args:
        sample: Dataset sample to check.
        rl: Reinforcement learning type.
        tokenizer: Tokenizer used for length calculation.
        sequence_len: Maximum allowed sequence length.

    Returns:
        Always True (an over-length sample raises instead of returning).

    Raises:
        ValueError: If the sample exceeds the configured sequence length.
    """
    # Reuse the drop-filter's length logic; a False result means "too long".
    if not _drop_long_sequences(sample, rl, tokenizer, sequence_len):
        raise ValueError(
            f"Sample exceeds configured sequence_len ({sequence_len}). "
            "Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
            "to handle long sequences automatically."
        )
    return True
def _truncate_long_sequences_rl(
    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
) -> dict[str, Any]:
    """Truncate over-length RL samples so they fit within ``sequence_len``.

    Preference datasets (DPO/IPO/ORPO/SIMPO) keep the full prompt and trim
    the chosen and rejected responses; KTO trims the completion the same
    way. Other RL types (GRPO/GDPO/EBFT) pass through untouched since their
    responses are generated at runtime.

    A sample whose prompt alone is longer than ``sequence_len`` cannot be
    fixed by trimming responses; it is returned as-is and is expected to be
    removed by a follow-up drop filter in the caller.

    Args:
        sample: Dataset sample to potentially truncate.
        rl: Reinforcement learning type.
        tokenizer: Tokenizer for encoding/decoding.
        sequence_len: Maximum allowed sequence length.

    Returns:
        The sample, with text fields truncated where necessary.
    """
    # Cheap early exit: skip tokenize/decode work for samples that already fit.
    if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
        return sample

    def encode(text: str) -> list:
        # Raw token ids only; special tokens are handled by the trainer later.
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
        if not (
            sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
        ):
            raise ValueError(
                "Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
            )
        prompt_len = len(encode(sample["prompt"]))
        response_ids = {key: encode(sample[key]) for key in ("chosen", "rejected")}
        budget = sequence_len - prompt_len
        if budget <= 0:
            # Prompt alone exceeds the limit; leave for the caller's drop filter.
            return sample
        truncated: dict[str, Any] = {
            key: tokenizer.decode(ids[:budget], skip_special_tokens=False)
            for key, ids in response_ids.items()
            if len(ids) > budget
        }
        if truncated:
            sample = {**sample, **truncated}
    elif rl is RLType.KTO:
        if not (sample.get("prompt") and sample.get("completion")):
            raise ValueError("Prompt and completion keys are required for KTO datasets")
        prompt_len = len(encode(sample["prompt"]))
        completion_ids = encode(sample["completion"])
        budget = sequence_len - prompt_len
        if budget <= 0:
            # Prompt alone exceeds the limit; leave for the caller's drop filter.
            return sample
        if len(completion_ids) > budget:
            sample = {
                **sample,
                "completion": tokenizer.decode(
                    completion_ids[:budget], skip_special_tokens=False
                ),
            }
    # GRPO/GDPO/EBFT: responses are generated at runtime, nothing to trim.
    return sample
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training.
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
split_datasets[i] = dataset
if not cfg.skip_prepare_dataset:
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
if excess_length_strategy == "truncate":
truncate_fn = partial(
_truncate_long_sequences_rl,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].map(
truncate_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# Drop samples that could not be truncated (e.g. prompt
# alone exceeds sequence_len)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Un-truncatable Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} samples from dataset index {i} "
f"that could not be truncated to fit sequence_len "
f"(prompt alone exceeds limit)"
)
elif excess_length_strategy == "raise":
raise_fn = partial(
_raise_on_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
split_datasets[i] = split_datasets[i].filter(
raise_fn,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Checking Sequence Lengths",
)
else: # "drop" (default)
drop_long = partial(
_drop_long_sequences,
rl=cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_num_proc,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
# Merge datasets
dataset = merge_datasets(split_datasets, cfg)