diff --git a/src/axolotl/utils/data/rl.py b/src/axolotl/utils/data/rl.py index 22c6a6194..3452628f0 100644 --- a/src/axolotl/utils/data/rl.py +++ b/src/axolotl/utils/data/rl.py @@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset): dataset.save_to_disk(str(prepared_ds_path)) -def map_dataset(cfg, data_set, ds_transform_fn, tokenizer): +def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs): sig = inspect.signature(ds_transform_fn) if "tokenizer" in sig.parameters: if not tokenizer: @@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer): data_set = data_set.map( ds_transform_fn, desc="Mapping RL Dataset", + **map_kwargs, ) return data_set @@ -150,13 +151,19 @@ def load_prepare_preference_datasets(cfg): else: ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) + map_kwargs = {} + if isinstance(ds_transform_fn, tuple): + ds_transform_fn, map_kwargs = ds_transform_fn split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer + cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs ) elif _cfg.rl == "kto": ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) + map_kwargs = {} + if isinstance(ds_transform_fn, tuple): + ds_transform_fn, map_kwargs = ds_transform_fn split_datasets[i] = map_dataset( - cfg, data_set, ds_transform_fn, tokenizer + cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs ) else: # If no `type` is provided, assume the dataset is already in the expected format with