Correctly handle splits for datasets.arrow_dataset.Dataset objects (#1504)
* Correctly handle splits for datasets.arrow_dataset.Dataset objects The `load_tokenized_prepared_datasets` function currently has logic for loading a dataset from local path that always checks if a split is in the dataset. The problem is, if the dataset is loaded using `load_from_disk` and it is an Arrow-based dataset, *there is no* split information. Instead what happens is, by calling `split in ds`, it presumably searches through all the rows and columns of the arrow dataset object to find e.g., 'train' assuming `split == 'train'`. This causes the program to hang. See https://chat.openai.com/share/0d567dbd-d60b-4079-9040-e1de58a4dff3 for context. * chore: lint --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -379,14 +379,15 @@ def load_tokenized_prepared_datasets(
|
||||
d_base_type = d_type_split[0]
|
||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||
|
||||
if config_dataset.split and config_dataset.split in ds:
|
||||
ds = ds[config_dataset.split]
|
||||
elif split in ds:
|
||||
ds = ds[split]
|
||||
elif isinstance(ds, DatasetDict):
|
||||
raise ValueError(
|
||||
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
|
||||
)
|
||||
if isinstance(ds, DatasetDict):
|
||||
if config_dataset.split and config_dataset.split in ds:
|
||||
ds = ds[config_dataset.split]
|
||||
elif split in ds:
|
||||
ds = ds[split]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"no {split} split found for dataset {config_dataset.path}, you may specify a split with 'split: `"
|
||||
)
|
||||
|
||||
# support for using a subset of the data
|
||||
if config_dataset.shards:
|
||||
|
||||
Reference in New Issue
Block a user