Honor `skip_prepare_dataset` when loading RL preference datasets

This commit is contained in:
Wing Lian
2025-02-02 23:39:10 -05:00
parent 54b0d3d0e8
commit 704ddd6ff1

View File

@@ -163,23 +163,26 @@ def load_prepare_preference_datasets(cfg):
# "prompt", "chosen" and "rejected" already preprocessed # "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set split_datasets[i] = data_set
drop_long = partial( if not cfg.skip_prepare_dataset:
drop_long_rl_seq, drop_long = partial(
rl=_cfg.rl, drop_long_rl_seq,
tokenizer=tokenizer, rl=_cfg.rl,
sequence_len=cfg.sequence_len, tokenizer=tokenizer,
) sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i]) prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter( split_datasets[i] = split_datasets[i].filter(
drop_long, drop_long,
num_proc=cfg.dataset_processes, num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess, load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences", desc="Dropping Long Sequences",
) )
dropped = prior_len - len(split_datasets[i]) dropped = prior_len - len(split_datasets[i])
if dropped: if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}") LOG.warning(
f"Dropped {dropped} long samples from dataset index {i}"
)
combined_datasets = concatenate_datasets(split_datasets) combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed) combined_datasets = combined_datasets.shuffle(seed=cfg.seed)