diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index b95f86ab2..f8a839b74 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -2,7 +2,7 @@ import inspect from functools import partial -from typing import Any, Callable, Literal +from typing import Any, Callable, Literal, List, Union from datasets import Dataset, DatasetDict from transformers import PreTrainedTokenizer @@ -120,6 +120,13 @@ def _map_dataset( ) return dataset +def drop_long_rl_seq(sample, rl, tokenizer, sequence_len, handling="drop"): + """ + Backward-compatibility wrapper for legacy imports in tests. + Delegates to the new predicate. + """ + return _drop_long_sequences(sample, rl, tokenizer, sequence_len) + def _drop_long_sequences( @@ -260,9 +267,7 @@ def load_prepare_preference_datasets(cfg): """ if rl in (RLType.DPO, RLType.IPO, RLType.ORPO, RLType.SIMPO): if not ( - sample.get("prompt") - and sample.get("chosen") - and sample.get("rejected") + sample.get("prompt") and sample.get("chosen") and sample.get("rejected") ): return False prompt = sample["prompt"] @@ -270,7 +275,9 @@ def load_prepare_preference_datasets(cfg): rejected = sample["rejected"] len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"]) len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"]) - len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"]) + len_rejected = len( + tokenizer(rejected, add_special_tokens=False)["input_ids"] + ) return (len_prompt + len_chosen) <= sequence_len and ( len_prompt + len_rejected ) <= sequence_len @@ -288,6 +295,7 @@ def load_prepare_preference_datasets(cfg): # GRPO does not enforce this check here return True return False + def load_split(dataset_cfgs, _cfg): split_datasets: List[Any] = [] use_auth_token = _cfg.hf_use_auth_token @@ -296,6 +304,8 @@ def load_prepare_preference_datasets(cfg): config_dataset, use_auth_token, streaming=False ) split_datasets.append(ds) + + def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: """Load and process dataset split for RL training. @@ -309,141 +319,6 @@ def _load_split(cfg: DictDefault, split: Literal["train", "test"]) -> Dataset: datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets split_datasets: list[Dataset | DatasetDict] = [] - map_kwargs = {} - if isinstance(ds_transform_fn, tuple): - ds_transform_fn, map_kwargs = ds_transform_fn - split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs - ) - elif _cfg.rl is RLType.KTO: - ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) - map_kwargs = {} - if isinstance(ds_transform_fn, tuple): - ds_transform_fn, map_kwargs = ds_transform_fn - split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs - ) - else: - # If no `type` is provided, assume the dataset is already in the expected format with - # "prompt", "chosen" and "rejected" already preprocessed - split_datasets[i] = data_set - - if not cfg.skip_prepare_dataset: - # Determine handling mode - # Support legacy alias "excess_token_handling" for compatibility - handling = cfg.get( - "sequence_len_overflow_handling", - cfg.get("excess_token_handling", "drop"), - ) - - drop_long = partial( - drop_long_rl_seq, - rl=_cfg.rl, - tokenizer=tokenizer, - sequence_len=cfg.sequence_len, - handling=handling, # Pass the handling mode - ) - - prior_len = len(split_datasets[i]) - - # Use map for truncate mode and filter for drop mode - if handling == "truncate": - split_datasets[i] = split_datasets[i].map( - drop_long, # Function now returns modified sample or original - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Truncating Long Sequences", - ) - # After truncation, drop any samples that still exceed sequence_len (e.g., prompt alone too long) - split_datasets[i] = split_datasets[i].filter( - partial( - _is_rl_seq_within_sequence_len, - rl=_cfg.rl, - tokenizer=tokenizer, - sequence_len=cfg.sequence_len, - ), - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Oversize Samples After Truncation", - ) - LOG.info( - f"Processed dataset index {i} with truncation handling for sequence length {cfg.sequence_len}" - ) - else: # handling == "drop" - split_datasets[i] = split_datasets[i].filter( - drop_long, # Function now returns boolean - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Dropping Long Sequences", - ) - dropped = prior_len - len(split_datasets[i]) - if dropped: - LOG.warning( - f"Dropped {dropped} long samples from dataset index {i}" - ) - - combined_datasets = concatenate_datasets(split_datasets) - combined_datasets = combined_datasets.shuffle(seed=cfg.seed or 42) - - return combined_datasets - - with zero_first(is_main_process()): - train_is_preprocessed = False - eval_is_preprocessed = False - if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets): - train_is_preprocessed = True - else: - train_dataset = load_split(cfg.datasets, cfg) - - eval_dataset = None - if cfg.test_datasets: - if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets): - eval_is_preprocessed = True - else: - eval_dataset = load_split(cfg.test_datasets, cfg) - if not eval_dataset: - if cfg.val_set_size: - seed = cfg.seed if cfg.seed is not None else 42 - - # ensure we end up with the same fingerprint by doing rank0 first and being able to cache - to_hash_train = ( - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "train" - + "|" - + str(seed) - ) - to_hash_test = ( - train_dataset._fingerprint # pylint: disable=protected-access - + "|" - + str(cfg.val_set_size) - + "|" - + "test" - + "|" - + str(seed) - ) - train_fingerprint = md5(to_hash_train) - test_fingerprint = md5(to_hash_test) - ds_w_test_split = train_dataset.train_test_split( - test_size=cfg.val_set_size, - seed=seed, - shuffle=False, - train_new_fingerprint=train_fingerprint, - test_new_fingerprint=test_fingerprint, - ) - eval_dataset = ds_w_test_split["test"] - train_dataset = ds_w_test_split["train"] - - if not train_is_preprocessed: - _save_preprocessed_ds(cfg, cfg.datasets, train_dataset) - if eval_dataset and not eval_is_preprocessed: - _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset) - - if cfg.dataset_exact_deduplication: - train_dataset, eval_dataset, _ = deduplicate_and_log_datasets( - train_dataset=train_dataset, eval_dataset=eval_dataset for dataset_config in datasets_with_name_generator(datasets_configs): dataset: Dataset | DatasetDict = load_dataset_with_config( dataset_config, cfg.hf_use_auth_token, streaming=False diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 63ffc6547..7085f5591 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -424,7 +424,7 @@ class AxolotlInputConfig( default=None, json_schema_extra={ "description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len" - } + }, ) min_sample_len: int | None = None max_prompt_len: int = Field(