pre-commit: fix rl.py imports/types; add legacy drop_long_rl_seq wrapper; resolve config schema; run formatting

This commit is contained in:
mhenrhcsen
2025-08-12 21:12:07 +02:00
parent 54b542d312
commit dc5887c652
2 changed files with 16 additions and 141 deletions

View File

@@ -2,7 +2,7 @@
import inspect import inspect
from functools import partial from functools import partial
from typing import Any, Callable, Literal from typing import Any, Callable, Literal, List, Union
from datasets import Dataset, DatasetDict from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@@ -120,6 +120,13 @@ def _map_dataset(
) )
return dataset return dataset
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
    """Legacy entry point kept so older imports (e.g. in tests) keep working.

    Forwards ``sample``, ``rl``, ``tokenizer`` and ``sequence_len`` positionally
    to ``_drop_long_sequences`` and returns its result unchanged.

    Args:
        sample: Passed through to ``_drop_long_sequences`` as-is.
        rl: Passed through to ``_drop_long_sequences`` as-is.
        tokenizer: Passed through to ``_drop_long_sequences`` as-is.
        sequence_len: Passed through to ``_drop_long_sequences`` as-is.
        handling: Accepted for signature compatibility only. It is NOT
            forwarded to the delegate, so a value other than ``"drop"`` does
            not change what this wrapper does; a warning is emitted in that
            case so the mismatch is visible to the caller.

    Returns:
        Whatever ``_drop_long_sequences`` returns for the given arguments
        (presumably a keep/drop boolean -- confirm against the delegate).
    """
    if handling != "drop":
        # The requested mode is not passed on to the delegate; say so rather
        # than letting the caller assume it was honored.
        import warnings

        warnings.warn(
            f"drop_long_rl_seq ignores handling={handling!r}; "
            "the call is delegated to _drop_long_sequences unchanged.",
            stacklevel=2,
        )
    return _drop_long_sequences(sample, rl, tokenizer, sequence_len)
def _drop_long_sequences( def _drop_long_sequences(
@@ -260,9 +267,7 @@ def load_prepare_preference_datasets(cfg):
""" """
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
if not ( if not (
sample.get("prompt") sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
and sample.get("chosen")
and sample.get("rejected")
): ):
return False return False
prompt = sample["prompt"] prompt = sample["prompt"]
@@ -270,7 +275,9 @@ def load_prepare_preference_datasets(cfg):
rejected = sample["rejected"] rejected = sample["rejected"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"]) len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"]) len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) len_rejected = len(
tokenizer(rejected, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_chosen) <= sequence_len and ( return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected len_prompt + len_rejected
) <= sequence_len ) <= sequence_len
@@ -288,6 +295,7 @@ def load_prepare_preference_datasets(cfg):
# GRPO does not enforce this check here # GRPO does not enforce this check here
return True return True
return False return False
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = [] split_datasets: List[Any] = []
use_auth_token = _cfg.hf_use_auth_token use_auth_token = _cfg.hf_use_auth_token
@@ -296,6 +304,8 @@ def load_prepare_preference_datasets(cfg):
config_dataset, use_auth_token, streaming=False config_dataset, use_auth_token, streaming=False
) )
split_datasets.append(ds) split_datasets.append(ds)
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
"""Load and process dataset split for RL training. """Load and process dataset split for RL training.
@@ -309,141 +319,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
split_datasets: list[Dataset | DatasetDict] = [] split_datasets: list[Dataset | DatasetDict] = []
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
elif _cfg.rl is RLType.KTO:
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
)
else:
# If no `type` is provided, assume the dataset is already in the expected format with
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
if not cfg.skip_prepare_dataset:
# Determine handling mode
# Support legacy alias "excess_token_handling" for compatibility
handling = cfg.get(
"sequence_len_overflow_handling",
cfg.get("excess_token_handling", "drop"),
)
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
handling=handling, # Pass the handling mode
)
prior_len = len(split_datasets[i])
# Use map for truncate mode and filter for drop mode
if handling == "truncate":
split_datasets[i] = split_datasets[i].map(
drop_long, # Function now returns modified sample or original
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Truncating Long Sequences",
)
# After truncation, drop any samples that still exceed sequence_len (e.g., prompt alone too long)
split_datasets[i] = split_datasets[i].filter(
partial(
_is_rl_seq_within_sequence_len,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
),
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Oversize Samples After Truncation",
)
LOG.info(
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
)
else: # handling == "drop"
split_datasets[i] = split_datasets[i].filter(
drop_long, # Function now returns boolean
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
return combined_datasets
with zero_first(is_main_process()):
train_is_preprocessed = False
eval_is_preprocessed = False
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
train_is_preprocessed = True
else:
train_dataset = load_split(cfg.datasets, cfg)
eval_dataset = None
if cfg.test_datasets:
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
eval_is_preprocessed = True
else:
eval_dataset = load_split(cfg.test_datasets, cfg)
if not eval_dataset:
if cfg.val_set_size:
seed = cfg.seed if cfg.seed is not None else 42
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(seed)
)
to_hash_test = (
train_dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(seed)
)
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)
ds_w_test_split = train_dataset.train_test_split(
test_size=cfg.val_set_size,
seed=seed,
shuffle=False,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
eval_dataset = ds_w_test_split["test"]
train_dataset = ds_w_test_split["train"]
if not train_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
if eval_dataset and not eval_is_preprocessed:
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
if cfg.dataset_exact_deduplication:
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
train_dataset=train_dataset, eval_dataset=eval_dataset
for dataset_config in datasets_with_name_generator(datasets_configs): for dataset_config in datasets_with_name_generator(datasets_configs):
dataset: Dataset | DatasetDict = load_dataset_with_config( dataset: Dataset | DatasetDict = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=False dataset_config, cfg.hf_use_auth_token, streaming=False

View File

@@ -424,7 +424,7 @@ class AxolotlInputConfig(
default=None, default=None,
json_schema_extra={ json_schema_extra={
"description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len" "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
} },
) )
min_sample_len: int | None = None min_sample_len: int | None = None
max_prompt_len: int = Field( max_prompt_len: int = Field(