add support for passing map kwargs to dataset map in rl
This commit is contained in:
@@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
||||
dataset.save_to_disk(str(prepared_ds_path))
|
||||
|
||||
|
||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||
sig = inspect.signature(ds_transform_fn)
|
||||
if "tokenizer" in sig.parameters:
|
||||
if not tokenizer:
|
||||
@@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
||||
data_set = data_set.map(
|
||||
ds_transform_fn,
|
||||
desc="Mapping RL Dataset",
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
return data_set
|
||||
@@ -150,13 +151,19 @@ def load_prepare_preference_datasets(cfg):
|
||||
else:
|
||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
elif _cfg.rl == "kto":
|
||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||
map_kwargs = {}
|
||||
if isinstance(ds_transform_fn, tuple):
|
||||
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||
split_datasets[i] = map_dataset(
|
||||
cfg, data_set, ds_transform_fn, tokenizer
|
||||
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||
)
|
||||
else:
|
||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||
|
||||
Reference in New Issue
Block a user