pre-commit: fix rl.py imports/types; add legacy drop_long_rl_seq wrapper; resolve config schema; run formatting
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Literal
|
from typing import Any, Callable, Literal, List, Union
|
||||||
|
|
||||||
from datasets import Dataset, DatasetDict
|
from datasets import Dataset, DatasetDict
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
@@ -120,6 +120,13 @@ def _map_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"):
    """
    Legacy entry point kept so that older imports (e.g. in tests) keep working.

    Forwards ``sample``, ``rl``, ``tokenizer`` and ``sequence_len`` unchanged,
    in that order, to ``_drop_long_sequences`` and returns whatever it returns.

    NOTE(review): ``handling`` is accepted only for signature compatibility;
    it is not forwarded, so its value has no effect here.
    """
    result = _drop_long_sequences(sample, rl, tokenizer, sequence_len)
    return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _drop_long_sequences(
|
def _drop_long_sequences(
|
||||||
@@ -260,9 +267,7 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
"""
|
"""
|
||||||
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO):
|
||||||
if not (
|
if not (
|
||||||
sample.get("prompt")
|
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
|
||||||
and sample.get("chosen")
|
|
||||||
and sample.get("rejected")
|
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
prompt = sample["prompt"]
|
prompt = sample["prompt"]
|
||||||
@@ -270,7 +275,9 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
rejected = sample["rejected"]
|
rejected = sample["rejected"]
|
||||||
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
|
||||||
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
|
||||||
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
|
len_rejected = len(
|
||||||
|
tokenizer(rejected, add_special_tokens=False)["input_ids"]
|
||||||
|
)
|
||||||
return (len_prompt + len_chosen) <= sequence_len and (
|
return (len_prompt + len_chosen) <= sequence_len and (
|
||||||
len_prompt + len_rejected
|
len_prompt + len_rejected
|
||||||
) <= sequence_len
|
) <= sequence_len
|
||||||
@@ -288,6 +295,7 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
# GRPO does not enforce this check here
|
# GRPO does not enforce this check here
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_split(dataset_cfgs, _cfg):
|
def load_split(dataset_cfgs, _cfg):
|
||||||
split_datasets: List[Any] = []
|
split_datasets: List[Any] = []
|
||||||
use_auth_token = _cfg.hf_use_auth_token
|
use_auth_token = _cfg.hf_use_auth_token
|
||||||
@@ -296,6 +304,8 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
config_dataset, use_auth_token, streaming=False
|
config_dataset, use_auth_token, streaming=False
|
||||||
)
|
)
|
||||||
split_datasets.append(ds)
|
split_datasets.append(ds)
|
||||||
|
|
||||||
|
|
||||||
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
||||||
"""Load and process dataset split for RL training.
|
"""Load and process dataset split for RL training.
|
||||||
|
|
||||||
@@ -309,141 +319,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset:
|
|||||||
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
|
||||||
split_datasets: list[Dataset | DatasetDict] = []
|
split_datasets: list[Dataset | DatasetDict] = []
|
||||||
|
|
||||||
map_kwargs = {}
|
|
||||||
if isinstance(ds_transform_fn, tuple):
|
|
||||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
|
||||||
split_datasets[i] = map_dataset(
|
|
||||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
|
||||||
)
|
|
||||||
elif _cfg.rl is RLType.KTO:
|
|
||||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
|
||||||
map_kwargs = {}
|
|
||||||
if isinstance(ds_transform_fn, tuple):
|
|
||||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
|
||||||
split_datasets[i] = map_dataset(
|
|
||||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
|
||||||
# "prompt", "chosen" and "rejected" already preprocessed
|
|
||||||
split_datasets[i] = data_set
|
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
|
||||||
# Determine handling mode
|
|
||||||
# Support legacy alias "excess_token_handling" for compatibility
|
|
||||||
handling = cfg.get(
|
|
||||||
"sequence_len_overflow_handling",
|
|
||||||
cfg.get("excess_token_handling", "drop"),
|
|
||||||
)
|
|
||||||
|
|
||||||
drop_long = partial(
|
|
||||||
drop_long_rl_seq,
|
|
||||||
rl=_cfg.rl,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
handling=handling, # Pass the handling mode
|
|
||||||
)
|
|
||||||
|
|
||||||
prior_len = len(split_datasets[i])
|
|
||||||
|
|
||||||
# Use map for truncate mode and filter for drop mode
|
|
||||||
if handling == "truncate":
|
|
||||||
split_datasets[i] = split_datasets[i].map(
|
|
||||||
drop_long, # Function now returns modified sample or original
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Truncating Long Sequences",
|
|
||||||
)
|
|
||||||
# After truncation, drop any samples that still exceed sequence_len (e.g., prompt alone too long)
|
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
|
||||||
partial(
|
|
||||||
_is_rl_seq_within_sequence_len,
|
|
||||||
rl=_cfg.rl,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
sequence_len=cfg.sequence_len,
|
|
||||||
),
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Dropping Oversize Samples After Truncation",
|
|
||||||
)
|
|
||||||
LOG.info(
|
|
||||||
f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}"
|
|
||||||
)
|
|
||||||
else: # handling == "drop"
|
|
||||||
split_datasets[i] = split_datasets[i].filter(
|
|
||||||
drop_long, # Function now returns boolean
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Dropping Long Sequences",
|
|
||||||
)
|
|
||||||
dropped = prior_len - len(split_datasets[i])
|
|
||||||
if dropped:
|
|
||||||
LOG.warning(
|
|
||||||
f"Dropped {dropped} long samples from dataset index {i}"
|
|
||||||
)
|
|
||||||
|
|
||||||
combined_datasets = concatenate_datasets(split_datasets)
|
|
||||||
combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42)
|
|
||||||
|
|
||||||
return combined_datasets
|
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
|
||||||
train_is_preprocessed = False
|
|
||||||
eval_is_preprocessed = False
|
|
||||||
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
|
|
||||||
train_is_preprocessed = True
|
|
||||||
else:
|
|
||||||
train_dataset = load_split(cfg.datasets, cfg)
|
|
||||||
|
|
||||||
eval_dataset = None
|
|
||||||
if cfg.test_datasets:
|
|
||||||
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
|
|
||||||
eval_is_preprocessed = True
|
|
||||||
else:
|
|
||||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
|
||||||
if not eval_dataset:
|
|
||||||
if cfg.val_set_size:
|
|
||||||
seed = cfg.seed if cfg.seed is not None else 42
|
|
||||||
|
|
||||||
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
|
||||||
to_hash_train = (
|
|
||||||
train_dataset._fingerprint # pylint: disable=protected-access
|
|
||||||
+ "|"
|
|
||||||
+ str(cfg.val_set_size)
|
|
||||||
+ "|"
|
|
||||||
+ "train"
|
|
||||||
+ "|"
|
|
||||||
+ str(seed)
|
|
||||||
)
|
|
||||||
to_hash_test = (
|
|
||||||
train_dataset._fingerprint # pylint: disable=protected-access
|
|
||||||
+ "|"
|
|
||||||
+ str(cfg.val_set_size)
|
|
||||||
+ "|"
|
|
||||||
+ "test"
|
|
||||||
+ "|"
|
|
||||||
+ str(seed)
|
|
||||||
)
|
|
||||||
train_fingerprint = md5(to_hash_train)
|
|
||||||
test_fingerprint = md5(to_hash_test)
|
|
||||||
ds_w_test_split = train_dataset.train_test_split(
|
|
||||||
test_size=cfg.val_set_size,
|
|
||||||
seed=seed,
|
|
||||||
shuffle=False,
|
|
||||||
train_new_fingerprint=train_fingerprint,
|
|
||||||
test_new_fingerprint=test_fingerprint,
|
|
||||||
)
|
|
||||||
eval_dataset = ds_w_test_split["test"]
|
|
||||||
train_dataset = ds_w_test_split["train"]
|
|
||||||
|
|
||||||
if not train_is_preprocessed:
|
|
||||||
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
|
||||||
if eval_dataset and not eval_is_preprocessed:
|
|
||||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
|
||||||
|
|
||||||
if cfg.dataset_exact_deduplication:
|
|
||||||
train_dataset, eval_dataset, _ = deduplicate_and_log_datasets(
|
|
||||||
train_dataset=train_dataset, eval_dataset=eval_dataset
|
|
||||||
for dataset_config in datasets_with_name_generator(datasets_configs):
|
for dataset_config in datasets_with_name_generator(datasets_configs):
|
||||||
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
dataset: Dataset | DatasetDict = load_dataset_with_config(
|
||||||
dataset_config, cfg.hf_use_auth_token, streaming=False
|
dataset_config, cfg.hf_use_auth_token, streaming=False
|
||||||
|
|||||||
@@ -424,7 +424,7 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
"description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
|
"description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int = Field(
|
||||||
|
|||||||
Reference in New Issue
Block a user