pre-commit: fix rl.py imports/types; add legacy drop_long_rl_seq wrapper; resolve config schema; run formatting

commit dc5887c652
parent 54b542d312
Author: mhenrhcsen
Date:   2025-08-12 21:12:07 +02:00

2 changed files with 16 additions and 141 deletions

File: rl.py

@@ -2,7 +2,7 @@
import inspect
from functools import partial
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, List, Union
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer
@@ -120,6 +120,13 @@ def _map_dataset(
    )
    return dataset

+
+def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
+    """
+    Backward-compatibility wrapper for legacy imports in tests.
+    Delegates to the new predicate.
+    """
+    return _drop_long_sequences(sample, rl, tokenizer, sequence_len)

def _drop_long_sequences(
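For context, this predicate gets bound with `functools.partial` and handed to `Dataset.filter`, as the removed code further down in this diff shows. A minimal sketch of that call pattern, with `cfg`, `tokenizer`, and `dataset` as stand-ins for objects in scope:

```python
from functools import partial

# Bind the static arguments once; the resulting callable takes a single
# sample, which is the signature `Dataset.filter` expects.
drop_long = partial(
    drop_long_rl_seq,
    rl=cfg.rl,
    tokenizer=tokenizer,
    sequence_len=cfg.sequence_len,
)
dataset = dataset.filter(drop_long, num_proc=cfg.dataset_processes)
```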
@@ -260,9 +267,7 @@ def load_prepare_preference_datasets(cfg):
"""
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not (
sample.get("prompt")
and sample.get("chosen")
and sample.get("rejected")
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
return False
prompt = sample["prompt"]
@@ -270,7 +275,9 @@ def load_prepare_preference_datasets(cfg):
        rejected = sample["rejected"]
        len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
        len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
-        len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
+        len_rejected = len(
+            tokenizer(rejected, add_special_tokens=False)["input_ids"]
+        )
        return (len_prompt + len_chosen) <= sequence_len and (
            len_prompt + len_rejected
        ) <= sequence_len
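The check above requires each preference pair to fit: prompt+chosen and prompt+rejected must both stay within `sequence_len`. A self-contained sketch of the same logic (the function name and the whitespace "tokenizer" are illustrative only, not the module's API):

```python
def fits_within(sample: dict, tokenizer, sequence_len: int) -> bool:
    """True iff prompt+chosen and prompt+rejected both fit in sequence_len."""

    def n_tokens(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=False)["input_ids"])

    n_prompt = n_tokens(sample["prompt"])
    return (n_prompt + n_tokens(sample["chosen"])) <= sequence_len and (
        n_prompt + n_tokens(sample["rejected"])
    ) <= sequence_len


# Toy usage: a whitespace "tokenizer" stands in for a real one.
toy = lambda text, add_special_tokens=False: {"input_ids": text.split()}
sample = {"prompt": "a b", "chosen": "c", "rejected": "d e f"}
print(fits_within(sample, toy, sequence_len=5))  # True (2+1 <= 5 and 2+3 <= 5)
```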
@@ -288,6 +295,7 @@ def load_prepare_preference_datasets(cfg):
        # GRPO does not enforce this check here
        return True
    return False
+

def load_split(dataset_cfgs, _cfg):
    split_datasets: List[Any] = []
    use_auth_token = _cfg.hf_use_auth_token
@@ -296,6 +304,8 @@ def load_prepare_preference_datasets(cfg):
            config_dataset, use_auth_token, streaming=False
        )
        split_datasets.append(ds)
+
+
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
    """Load and process dataset split for RL training.
@@ -309,141 +319,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
    datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
    split_datasets: list[Dataset | DatasetDict] = []
-            map_kwargs = {}
-            if isinstance(ds_transform_fn, tuple):
-                ds_transform_fn, map_kwargs = ds_transform_fn
-            split_datasets[i] = map_dataset(
-                cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
-            )
-        elif _cfg.rl is RLType.KTO:
-            ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
-            map_kwargs = {}
-            if isinstance(ds_transform_fn, tuple):
-                ds_transform_fn, map_kwargs = ds_transform_fn
-            split_datasets[i] = map_dataset(
-                cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
-            )
-        else:
-            # If no `type` is provided, assume the dataset is already in the expected format with
-            # "prompt", "chosen" and "rejected" already preprocessed
-            split_datasets[i] = data_set
-
-        if not cfg.skip_prepare_dataset:
-            # Determine handling mode
-            # Support legacy alias "excess_token_handling" for compatibility
-            handling = cfg.get(
-                "sequence_len_overflow_handling",
-                cfg.get("excess_token_handling", "drop"),
-            )
-            drop_long = partial(
-                drop_long_rl_seq,
-                rl=_cfg.rl,
-                tokenizer=tokenizer,
-                sequence_len=cfg.sequence_len,
-                handling=handling,  # Pass the handling mode
-            )
-
-            prior_len = len(split_datasets[i])
-            # Use map for truncate mode and filter for drop mode
-            if handling == "truncate":
-                split_datasets[i] = split_datasets[i].map(
-                    drop_long,  # Function now returns modified sample or original
-                    num_proc=cfg.dataset_processes,
-                    load_from_cache_file=not cfg.is_preprocess,
-                    desc="Truncating Long Sequences",
-                )
-                # After truncation, drop any samples that still exceed sequence_len (e.g., prompt alone too long)
-                split_datasets[i] = split_datasets[i].filter(
-                    partial(
-                        _is_rl_seq_within_sequence_len,
-                        rl=_cfg.rl,
-                        tokenizer=tokenizer,
-                        sequence_len=cfg.sequence_len,
-                    ),
-                    num_proc=cfg.dataset_processes,
-                    load_from_cache_file=not cfg.is_preprocess,
-                    desc="Dropping Oversize Samples After Truncation",
-                )
-                LOG.info(
-                    f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
-                )
-            else:  # handling == "drop"
-                split_datasets[i] = split_datasets[i].filter(
-                    drop_long,  # Function now returns boolean
-                    num_proc=cfg.dataset_processes,
-                    load_from_cache_file=not cfg.is_preprocess,
-                    desc="Dropping Long Sequences",
-                )
-                dropped = prior_len - len(split_datasets[i])
-                if dropped:
-                    LOG.warning(
-                        f"Dropped {dropped} long samples from dataset index {i}"
-                    )
-
-    combined_datasets = concatenate_datasets(split_datasets)
-    combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
-    return combined_datasets
-
-    with zero_first(is_main_process()):
-        train_is_preprocessed = False
-        eval_is_preprocessed = False
-        if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
-            train_is_preprocessed = True
-        else:
-            train_dataset = load_split(cfg.datasets, cfg)
-
-        eval_dataset = None
-        if cfg.test_datasets:
-            if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
-                eval_is_preprocessed = True
-            else:
-                eval_dataset = load_split(cfg.test_datasets, cfg)
-
-        if not eval_dataset:
-            if cfg.val_set_size:
-                seed = cfg.seed if cfg.seed is not None else 42
-                # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
-                to_hash_train = (
-                    train_dataset._fingerprint  # pylint: disable=protected-access
-                    + "|"
-                    + str(cfg.val_set_size)
-                    + "|"
-                    + "train"
-                    + "|"
-                    + str(seed)
-                )
-                to_hash_test = (
-                    train_dataset._fingerprint  # pylint: disable=protected-access
-                    + "|"
-                    + str(cfg.val_set_size)
-                    + "|"
-                    + "test"
-                    + "|"
-                    + str(seed)
-                )
-                train_fingerprint = md5(to_hash_train)
-                test_fingerprint = md5(to_hash_test)
-
-                ds_w_test_split = train_dataset.train_test_split(
-                    test_size=cfg.val_set_size,
-                    seed=seed,
-                    shuffle=False,
-                    train_new_fingerprint=train_fingerprint,
-                    test_new_fingerprint=test_fingerprint,
-                )
-                eval_dataset = ds_w_test_split["test"]
-                train_dataset = ds_w_test_split["train"]
-
-        if not train_is_preprocessed:
-            _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
-        if eval_dataset and not eval_is_preprocessed:
-            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
-
-    if cfg.dataset_exact_deduplication:
-        train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
-            train_dataset=train_dataset, eval_dataset=eval_dataset
-        )
    for dataset_config in datasets_with_name_generator(datasets_configs):
        dataset: Dataset | DatasetDict = load_dataset_with_config(
            dataset_config, cfg.hf_use_auth_token, streaming=False
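Two notes on the removed block. First, it supported two overflow policies, selected via `sequence_len_overflow_handling` (with the legacy alias `excess_token_handling`, defaulting to `"drop"`): "truncate" maps samples in place and then filters out anything that still cannot fit, while "drop" filters directly. A self-contained toy demo of that map-then-filter versus single-filter distinction, with token-id lists standing in for tokenized text:

```python
from datasets import Dataset

seq_len = 5
ds = Dataset.from_dict({"ids": [[1, 2, 3], [1, 2, 3, 4, 5, 6, 7]]})

# "truncate": map rewrites oversize rows in place...
truncated = ds.map(lambda s: {"ids": s["ids"][:seq_len]})
# ...then a filter drops whatever still exceeds the limit
# (e.g., a prompt that is too long on its own).
truncated = truncated.filter(lambda s: len(s["ids"]) <= seq_len)

# "drop": a single filter removes oversize rows outright.
dropped = ds.filter(lambda s: len(s["ids"]) <= seq_len)

print(len(truncated), len(dropped))  # 2 1
```

Second, the deleted split logic pinned both output fingerprints so the train/test split stays cache-stable when rank 0 runs first. A hedged sketch of that pattern; `md5` here is a stand-in for the repo's string-hash helper, and `cfg`/`train_dataset` are assumed to be in scope:

```python
from hashlib import md5 as _hashlib_md5


def md5(text: str) -> str:
    # Stand-in for the repo's md5 helper: hex digest of a UTF-8 string.
    return _hashlib_md5(text.encode("utf-8")).hexdigest()


seed = cfg.seed if cfg.seed is not None else 42
base = train_dataset._fingerprint  # pylint: disable=protected-access
split_ds = train_dataset.train_test_split(
    test_size=cfg.val_set_size,
    seed=seed,
    shuffle=False,
    # Deterministic fingerprints let later ranks reuse rank 0's cache.
    train_new_fingerprint=md5(f"{base}|{cfg.val_set_size}|train|{seed}"),
    test_new_fingerprint=md5(f"{base}|{cfg.val_set_size}|test|{seed}"),
)
train_dataset, eval_dataset = split_ds["train"], split_ds["test"]
```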

File: config schema (class AxolotlInputConfig)

@@ -424,7 +424,7 @@ class AxolotlInputConfig(
        default=None,
        json_schema_extra={
            "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
-        }
+        },
    )
    min_sample_len: int | None = None
    max_prompt_len: int = Field(
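The schema change above is formatting only (a trailing comma after the `json_schema_extra` dict). For reference, a minimal sketch of the pydantic `Field` pattern this schema uses; the model and field names here are illustrative, not the repo's:

```python
from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    # json_schema_extra carries the human-readable description into the
    # generated JSON schema, as in the AxolotlInputConfig fields above.
    eval_input_len: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "Maximum input length for evaluation; defaults to sequence_len"
        },
    )
```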