diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 9f5c726ab..22c6a6194 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -163,23 +163,26 @@ def load_prepare_preference_datasets(cfg): # "prompt", "chosen" and "rejected" already preprocessed split_datasets[i] = data_set - drop_long = partial( - drop_long_rl_seq, - rl=_cfg.rl, - tokenizer=tokenizer, - sequence_len=cfg.sequence_len, - ) + if not cfg.skip_prepare_dataset: + drop_long = partial( + drop_long_rl_seq, + rl=_cfg.rl, + tokenizer=tokenizer, + sequence_len=cfg.sequence_len, + ) - prior_len = len(split_datasets[i]) - split_datasets[i] = split_datasets[i].filter( - drop_long, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Long Sequences", - ) - dropped = prior_len - len(split_datasets[i]) - if dropped: - LOG.warning(f"Dropped {dropped} long samples from dataset index {i}") + prior_len = len(split_datasets[i]) + split_datasets[i] = split_datasets[i].filter( + drop_long, + num_proc=cfg.dataset_processes, + load_from_cache_file=not cfg.is_preprocess, + desc="Dropping Long Sequences", + ) + dropped = prior_len - len(split_datasets[i]) + if dropped: + LOG.warning( + f"Dropped {dropped} long samples from dataset index {i}" + ) combined_datasets = concatenate_datasets(split_datasets) combined_datasets = combined_datasets.shuffle(seed=cfg.seed)