don't fail during preprocess for sampling from iterable dataset (#2825) [skip ci]

This commit is contained in:
Wing Lian
2025-06-27 10:37:53 -04:00
committed by GitHub
parent 29289a4de9
commit 24f2887e87

View File

@@ -75,13 +75,17 @@ def load_datasets(
num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False
train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=num_examples,
text_only=text_only,
)
try:
train_samples = sample_dataset(train_dataset, num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=num_examples,
text_only=text_only,
)
except AttributeError:
# can't sample iterable datasets
pass
LOG.info("printing prompters...")
for prompter in prompters: