don't fail during preprocess for sampling from iterable dataset (#2825) [skip ci]

This commit is contained in:
Wing Lian
2025-06-27 10:37:53 -04:00
committed by GitHub
parent 29289a4de9
commit 24f2887e87

View File

@@ -75,6 +75,7 @@ def load_datasets(
num_examples = cli_args.debug_num_examples if cli_args else 1 num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False text_only = cli_args.debug_text_only if cli_args else False
try:
train_samples = sample_dataset(train_dataset, num_examples) train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels( check_dataset_labels(
train_samples, train_samples,
@@ -82,6 +83,9 @@ def load_datasets(
num_examples=num_examples, num_examples=num_examples,
text_only=text_only, text_only=text_only,
) )
except AttributeError:
# can't sample iterable datasets
pass
LOG.info("printing prompters...") LOG.info("printing prompters...")
for prompter in prompters: for prompter in prompters: