add support for passing map kwargs to dataset map in rl
This commit is contained in:
@@ -57,7 +57,7 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
|||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
|
|
||||||
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs):
|
||||||
sig = inspect.signature(ds_transform_fn)
|
sig = inspect.signature(ds_transform_fn)
|
||||||
if "tokenizer" in sig.parameters:
|
if "tokenizer" in sig.parameters:
|
||||||
if not tokenizer:
|
if not tokenizer:
|
||||||
@@ -70,6 +70,7 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
|||||||
data_set = data_set.map(
|
data_set = data_set.map(
|
||||||
ds_transform_fn,
|
ds_transform_fn,
|
||||||
desc="Mapping RL Dataset",
|
desc="Mapping RL Dataset",
|
||||||
|
**map_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return data_set
|
return data_set
|
||||||
@@ -150,13 +151,19 @@ def load_prepare_preference_datasets(cfg):
|
|||||||
else:
|
else:
|
||||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
||||||
|
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
elif _cfg.rl == "kto":
|
elif _cfg.rl == "kto":
|
||||||
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
||||||
|
map_kwargs = {}
|
||||||
|
if isinstance(ds_transform_fn, tuple):
|
||||||
|
ds_transform_fn, map_kwargs = ds_transform_fn
|
||||||
split_datasets[i] = map_dataset(
|
split_datasets[i] = map_dataset(
|
||||||
cfg, data_set, ds_transform_fn, tokenizer
|
cfg, data_set, ds_transform_fn, tokenizer, **map_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
# If no `type` is provided, assume the dataset is already in the expected format with
|
||||||
|
|||||||
Reference in New Issue
Block a user