Add support for passing extra map kwargs through to the dataset `map` call for RL datasets

This commit is contained in:
Wing Lian
2025-02-03 09:04:30 -05:00
parent 1e94d7ef65
commit 3c7517fd55

View File

@@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
dataset.save_to_disk(str(prepared_ds_path)) dataset.save_to_disk(str(prepared_ds_path))
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer): def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
sig = inspect.signature(ds_transform_fn) sig = inspect.signature(ds_transform_fn)
if "tokenizer" in sig.parameters: if "tokenizer" in sig.parameters:
if not tokenizer: if not tokenizer:
@@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
data_set = data_set.map( data_set = data_set.map(
ds_transform_fn, ds_transform_fn,
desc="Mapping RL Dataset", desc="Mapping RL Dataset",
**map_kwargs,
) )
return data_set return data_set
@@ -150,13 +151,19 @@ def load_prepare_preference_datasets(cfg):
else: else:
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i) ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset( split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
) )
elif _cfg.rl == "kto": elif _cfg.rl == "kto":
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i) ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
map_kwargs = {}
if isinstance(ds_transform_fn, tuple):
ds_transform_fn, map_kwargs = ds_transform_fn
split_datasets[i] = map_dataset( split_datasets[i] = map_dataset(
cfg, data_set, ds_transform_fn, tokenizer cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
) )
else: else:
# If no `type` is provided, assume the dataset is already in the expected format with # If no `type` is provided, assume the dataset is already in the expected format with