feat: support excess_length_strategy for RL trainers (#3578) [skip ci]
* feat: support excess_length_strategy for RL trainers

Previously, RL data loading always dropped sequences exceeding sequence_len. This adds support for the existing `excess_length_strategy` config option (`drop`, `truncate`, `raise`) in RL training pipelines, matching the behavior already available for SFT.

- `drop` (default): unchanged behavior, filters out long samples
- `truncate`: tokenizes text components, truncates responses to fit within sequence_len while preserving the full prompt, then decodes back to text. Handles DPO/IPO/ORPO/SIMPO and KTO datasets.
- `raise`: raises ValueError if any sample exceeds sequence_len

Closes #3547

* improve RL truncation strategy robustness and performance

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
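For context, a minimal config fragment that would enable the new behavior; the values below are illustrative placeholders and remaining required config keys are omitted. Only `excess_length_strategy`, together with the existing `rl` and `sequence_len` keys, is what this change acts on:

```yaml
# illustrative fragment; other required config keys omitted
rl: dpo
sequence_len: 2048
excess_length_strategy: truncate  # drop (default) | truncate | raise
```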
@@ -180,6 +180,119 @@ def _drop_long_sequences(
     raise ValueError("Unknown RL type")
 
 
+def _raise_on_long_sequences(
+    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
+) -> bool:
+    """Check sequence length and raise ValueError if exceeded.
+
+    Used as a filter function for ``excess_length_strategy: raise``.
+
+    Args:
+        sample: Dataset sample to check.
+        rl: Reinforcement learning type.
+        tokenizer: Tokenizer for length calculation.
+        sequence_len: Maximum allowed sequence length.
+
+    Returns:
+        Always True (raises before returning False).
+
+    Raises:
+        ValueError: If any sample exceeds the configured sequence length.
+    """
+    is_valid = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
+    if not is_valid:
+        raise ValueError(
+            f"Sample exceeds configured sequence_len ({sequence_len}). "
+            "Set `excess_length_strategy: drop` or `excess_length_strategy: truncate` "
+            "to handle long sequences automatically."
+        )
+    return True
+
+
+def _truncate_long_sequences_rl(
+    sample: dict[str, Any], rl: RLType, tokenizer: Any, sequence_len: int
+) -> dict[str, Any]:
+    """Truncate RL samples that exceed maximum sequence length.
+
+    For preference datasets (DPO/IPO/ORPO/SIMPO), truncates chosen and rejected
+    responses to fit within ``sequence_len`` when combined with the prompt.
+    For KTO, truncates the completion similarly.
+    GRPO/GDPO/EBFT samples are returned unchanged.
+
+    Samples where the prompt alone exceeds ``sequence_len`` cannot be
+    meaningfully truncated and are returned unchanged. The caller should
+    follow up with a drop filter to remove them.
+
+    Args:
+        sample: Dataset sample to potentially truncate.
+        rl: Reinforcement learning type.
+        tokenizer: Tokenizer for encoding/decoding.
+        sequence_len: Maximum allowed sequence length.
+
+    Returns:
+        The sample with text fields truncated to fit within sequence_len.
+    """
+    # Fast path: if sample already fits, return unchanged (avoids decode overhead)
+    if _drop_long_sequences(sample, rl, tokenizer, sequence_len):
+        return sample
+
+    if rl in {RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO}:
+        if not (
+            sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
+        ):
+            raise ValueError(
+                "Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
+            )
+
+        prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
+        chosen_ids = tokenizer(sample["chosen"], add_special_tokens=False)["input_ids"]
+        rejected_ids = tokenizer(sample["rejected"], add_special_tokens=False)[
+            "input_ids"
+        ]
+
+        max_response_len = sequence_len - len(prompt_ids)
+        if max_response_len <= 0:
+            # Prompt alone exceeds limit; cannot meaningfully truncate.
+            # Returned unchanged — the follow-up drop filter will remove it.
+            return sample
+
+        updates: dict[str, Any] = {}
+        if len(chosen_ids) > max_response_len:
+            updates["chosen"] = tokenizer.decode(
+                chosen_ids[:max_response_len], skip_special_tokens=False
+            )
+        if len(rejected_ids) > max_response_len:
+            updates["rejected"] = tokenizer.decode(
+                rejected_ids[:max_response_len], skip_special_tokens=False
+            )
+        if updates:
+            sample = {**sample, **updates}
+
+    elif rl is RLType.KTO:
+        if not (sample.get("prompt") and sample.get("completion")):
+            raise ValueError("Prompt and completion keys are required for KTO datasets")
+
+        prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
+        completion_ids = tokenizer(sample["completion"], add_special_tokens=False)[
+            "input_ids"
+        ]
+
+        max_completion_len = sequence_len - len(prompt_ids)
+        if max_completion_len <= 0:
+            return sample
+
+        if len(completion_ids) > max_completion_len:
+            sample = {
+                **sample,
+                "completion": tokenizer.decode(
+                    completion_ids[:max_completion_len], skip_special_tokens=False
+                ),
+            }
+
+    # GRPO/GDPO/EBFT: no truncation needed (responses generated at runtime)
+    return sample
+
+
 def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
     """Load and process dataset split for RL training.
 
@@ -243,23 +356,77 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
             split_datasets[i] = dataset
 
         if not cfg.skip_prepare_dataset:
-            drop_long = partial(
-                _drop_long_sequences,
-                rl=cfg.rl,
-                tokenizer=tokenizer,
-                sequence_len=cfg.sequence_len,
-            )
+            excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
 
-            prior_len = len(split_datasets[i])
-            split_datasets[i] = split_datasets[i].filter(
-                drop_long,
-                num_proc=cfg.dataset_num_proc,
-                load_from_cache_file=not cfg.is_preprocess,
-                desc="Dropping Long Sequences",
-            )
-            dropped = prior_len - len(split_datasets[i])
-            if dropped:
-                LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
+            if excess_length_strategy == "truncate":
+                truncate_fn = partial(
+                    _truncate_long_sequences_rl,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+                prior_len = len(split_datasets[i])
+                split_datasets[i] = split_datasets[i].map(
+                    truncate_fn,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Truncating Long Sequences",
+                )
+
+                # Drop samples that could not be truncated (e.g. prompt
+                # alone exceeds sequence_len)
+                drop_long = partial(
+                    _drop_long_sequences,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+                split_datasets[i] = split_datasets[i].filter(
+                    drop_long,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Dropping Un-truncatable Sequences",
+                )
+                dropped = prior_len - len(split_datasets[i])
+                if dropped:
+                    LOG.warning(
+                        f"Dropped {dropped} samples from dataset index {i} "
+                        f"that could not be truncated to fit sequence_len "
+                        f"(prompt alone exceeds limit)"
+                    )
+            elif excess_length_strategy == "raise":
+                raise_fn = partial(
+                    _raise_on_long_sequences,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+                split_datasets[i] = split_datasets[i].filter(
+                    raise_fn,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Checking Sequence Lengths",
+                )
+            else:  # "drop" (default)
+                drop_long = partial(
+                    _drop_long_sequences,
+                    rl=cfg.rl,
+                    tokenizer=tokenizer,
+                    sequence_len=cfg.sequence_len,
+                )
+
+                prior_len = len(split_datasets[i])
+                split_datasets[i] = split_datasets[i].filter(
+                    drop_long,
+                    num_proc=cfg.dataset_num_proc,
+                    load_from_cache_file=not cfg.is_preprocess,
+                    desc="Dropping Long Sequences",
+                )
+                dropped = prior_len - len(split_datasets[i])
+                if dropped:
+                    LOG.warning(
+                        f"Dropped {dropped} long samples from dataset index {i}"
+                    )
 
     # Merge datasets
     dataset = merge_datasets(split_datasets, cfg)
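To see the truncation idea in isolation, below is a minimal standalone sketch of the preference-dataset branch described above, using a Hugging Face tokenizer. It is an illustration written under the assumptions stated in its comments (placeholder tokenizer, toy sample), not the axolotl implementation itself:

```python
from transformers import AutoTokenizer


def truncate_preference_sample(sample: dict, tokenizer, sequence_len: int) -> dict:
    """Trim chosen/rejected so that prompt + response fits within sequence_len."""
    prompt_ids = tokenizer(sample["prompt"], add_special_tokens=False)["input_ids"]
    max_response_len = sequence_len - len(prompt_ids)
    if max_response_len <= 0:
        # Prompt alone exceeds the limit; nothing sensible to truncate here.
        return sample

    out = dict(sample)
    for key in ("chosen", "rejected"):
        ids = tokenizer(sample[key], add_special_tokens=False)["input_ids"]
        if len(ids) > max_response_len:
            # Truncate in token space, then decode back to text.
            out[key] = tokenizer.decode(ids[:max_response_len], skip_special_tokens=False)
    return out


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    sample = {
        "prompt": "Q: Summarize the plot.\nA:",
        "chosen": "It is a long story. " * 100,
        "rejected": "No. " * 300,
    }
    short = truncate_preference_sample(sample, tok, sequence_len=64)
    print(len(tok(short["chosen"], add_special_tokens=False)["input_ids"]))
```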