pre-commit: fix rl.py imports/types; add legacy drop_long_rl_seq wrapper; resolve config schema; run formatting
@@ -2,7 +2,7 @@

 import inspect
 from functools import partial
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, List, Union

 from datasets import Dataset, DatasetDict
 from transformers import PreTrainedTokenizer
@@ -120,6 +120,13 @@ def _map_dataset(
     )

     return dataset
+
+def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
+    """
+    Backward-compatibility wrapper for legacy imports in tests.
+    Delegates to the new predicate.
+    """
+    return _drop_long_sequences(sample, rl, tokenizer, sequence_len)


 def _drop_long_sequences(
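Note on the wrapper above: it restores the old `drop_long_rl_seq` import path that tests still use, while the filtering logic itself lives in `_drop_long_sequences`. A minimal sketch of the contract from a caller's side, using a toy whitespace tokenizer; everything except the wrapper's signature is an assumption:

# Toy stand-ins for illustration only; the real predicate and tokenizer
# differ. Only drop_long_rl_seq's signature is taken from the diff.
def toy_tokenizer(text, add_special_tokens=False):
    # Whitespace "tokenization" in place of a real PreTrainedTokenizer.
    return {"input_ids": text.split()}


def _drop_long_sequences(sample, rl, tokenizer, sequence_len):
    # Simplified stand-in: keep the sample only if prompt + chosen fits.
    ids = tokenizer(sample["prompt"] + sample["chosen"])["input_ids"]
    return len(ids) <= sequence_len


def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
    # Legacy entry point kept for old imports; delegates unchanged.
    return _drop_long_sequences(sample, rl, tokenizer, sequence_len)


sample = {"prompt": "why is the sky blue ", "chosen": "rayleigh scattering"}
print(drop_long_rl_seq(sample, rl="dpo", tokenizer=toy_tokenizer, sequence_len=16))  # True

With the default handling="drop", the return value is a boolean, which is what Dataset.filter expects.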
@@ -260,9 +267,7 @@ def load_prepare_preference_datasets(cfg):
         """
         if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
             if not (
-                sample.get("prompt")
-                and sample.get("chosen")
-                and sample.get("rejected")
+                sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
             ):
                 return False
             prompt = sample["prompt"]
@@ -270,7 +275,9 @@ def load_prepare_preference_datasets(cfg):
             rejected = sample["rejected"]
             len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
             len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
-            len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
+            len_rejected = len(
+                tokenizer(rejected, add_special_tokens=False)["input_ids"]
+            )
             return (len_prompt + len_chosen) <= sequence_len and (
                 len_prompt + len_rejected
             ) <= sequence_len
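Note: the reflowed predicate above keeps a preference pair only when both completions fit alongside the prompt. A quick worked example with assumed token counts:

# Assumed numbers for illustration; mirrors the return expression above.
sequence_len = 2048
len_prompt, len_chosen, len_rejected = 1500, 400, 700
keep = (len_prompt + len_chosen) <= sequence_len and (
    len_prompt + len_rejected
) <= sequence_len
print(keep)  # False: 1500 + 700 = 2200 > 2048, so the pair is dropped

One oversize side discards the whole pair; chosen and rejected are never filtered independently.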
@@ -288,6 +295,7 @@ def load_prepare_preference_datasets(cfg):
             # GRPO does not enforce this check here
             return True
         return False

     def load_split(dataset_cfgs, _cfg):
+        split_datasets: List[Any] = []
         use_auth_token = _cfg.hf_use_auth_token
@@ -296,6 +304,8 @@ def load_prepare_preference_datasets(cfg):
                 config_dataset, use_auth_token, streaming=False
             )
             split_datasets.append(ds)


 def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
+    """Load and process dataset split for RL training.
+
@@ -309,141 +319,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
||||
split_datasets: list[Dataset | DatasetDict] = []
|
||||
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
elif _cfg.rl is RLType.KTO:
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
# "prompt", "chosen" and "rejected" already preprocessed
|
||||
split_datasets[i] = data_set
|
||||
|
||||
if not cfg.skip_prepare_dataset:
|
||||
# Determine handling mode
|
||||
# Support legacy alias "excess_token_handling" for compatibility
|
||||
handling = cfg.get(
|
||||
"sequence_len_overflow_handling",
|
||||
cfg.get("excess_token_handling", "drop"),
|
||||
)
|
||||
|
||||
drop_long = partial(
|
||||
drop_long_rl_seq,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
handling=handling, # Pass the handling mode
|
||||
)
|
||||
|
||||
prior_len = len(split_datasets[i])
|
||||
|
||||
# Use map for truncate mode and filter for drop mode
|
||||
if handling == "truncate":
|
||||
split_datasets[i] = split_datasets[i].map(
|
||||
drop_long, # Function now returns modified sample or original
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Truncating Long Sequences",
|
||||
)
|
||||
# After truncation, drop any samples that still exceed sequence_len (e.g., prompt alone too long)
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
partial(
|
||||
_is_rl_seq_within_sequence_len,
|
||||
rl=_cfg.rl,
|
||||
tokenizer=tokenizer,
|
||||
sequence_len=cfg.sequence_len,
|
||||
),
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Oversize Samples After Truncation",
|
||||
)
|
||||
LOG.info(
|
||||
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
|
||||
)
|
||||
else: # handling == "drop"
|
||||
split_datasets[i] = split_datasets[i].filter(
|
||||
drop_long, # Function now returns boolean
|
||||
num_proc=cfg.dataset_processes,
|
||||
load_from_cache_file=not cfg.is_preprocess,
|
||||
desc="Dropping Long Sequences",
|
||||
)
|
||||
dropped = prior_len - len(split_datasets[i])
|
||||
if dropped:
|
||||
LOG.warning(
|
||||
f"Dropped {dropped} long samples from dataset index {i}"
|
||||
)
|
||||
|
||||
combined_datasets = concatenate_datasets(split_datasets)
|
||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
|
||||
|
||||
return combined_datasets
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_is_preprocessed = False
|
||||
eval_is_preprocessed = False
|
||||
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
|
||||
train_is_preprocessed = True
|
||||
else:
|
||||
train_dataset = load_split(cfg.datasets, cfg)
|
||||
|
||||
eval_dataset = None
|
||||
if cfg.test_datasets:
|
||||
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
|
||||
eval_is_preprocessed = True
|
||||
else:
|
||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
||||
if not eval_dataset:
|
||||
if cfg.val_set_size:
|
||||
seed = cfg.seed if cfg.seed is not None else 42
|
||||
|
||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
||||
to_hash_train = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "train"
|
||||
+ "|"
|
||||
+ str(seed)
|
||||
)
|
||||
to_hash_test = (
|
||||
train_dataset._fingerprint # pylint: disable=protected-access
|
||||
+ "|"
|
||||
+ str(cfg.val_set_size)
|
||||
+ "|"
|
||||
+ "test"
|
||||
+ "|"
|
||||
+ str(seed)
|
||||
)
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
ds_w_test_split = train_dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
seed=seed,
|
||||
shuffle=False,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
eval_dataset = ds_w_test_split["test"]
|
||||
train_dataset = ds_w_test_split["train"]
|
||||
|
||||
if not train_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
||||
if eval_dataset and not eval_is_preprocessed:
|
||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
||||
|
||||
if cfg.dataset_exact_deduplication:
|
||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
||||
for dataset_config in datasets_with_name_generator(datasets_configs):
|
||||
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, streaming=False
|
||||
|
||||
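Note on the removed block above: it dispatched on the handling mode, using Dataset.map to rewrite samples in "truncate" mode (with a second filter pass for samples whose prompt alone exceeds the budget) and Dataset.filter directly in "drop" mode. A condensed sketch; the helper names and signature here are assumptions, while the map/filter calls follow the datasets API:

from datasets import Dataset


def apply_handling(ds: Dataset, handling: str, drop_long, still_fits) -> Dataset:
    # Condensed from the removed logic; drop_long and still_fits stand in
    # for the partials built around drop_long_rl_seq and
    # _is_rl_seq_within_sequence_len.
    if handling == "truncate":
        ds = ds.map(drop_long, desc="Truncating Long Sequences")
        # A truncated sample can still be oversize (e.g. the prompt alone is
        # longer than sequence_len), hence the second pass.
        ds = ds.filter(still_fits, desc="Dropping Oversize Samples After Truncation")
    else:  # handling == "drop"
        ds = ds.filter(drop_long, desc="Dropping Long Sequences")
    return ds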
@@ -424,7 +424,7 @@ class AxolotlInputConfig(
         default=None,
         json_schema_extra={
             "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
-        }
+        },
     )
     min_sample_len: int | None = None
     max_prompt_len: int = Field(
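Note on the schema hunk: the fix is a trailing comma after the json_schema_extra dict, matching the formatter's style; the description itself is unchanged. A minimal pydantic v2 sketch of the pattern; the model and field names here are hypothetical:

from pydantic import BaseModel, Field


class ExampleConfig(BaseModel):
    # Hypothetical field; the Field/json_schema_extra shape mirrors the diff.
    eval_max_input_len: int | None = Field(
        default=None,
        json_schema_extra={
            "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
        },
    )


print(ExampleConfig.model_json_schema()["properties"]["eval_max_input_len"])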