Don't fail during preprocessing when sampling from an iterable dataset (#2825) [skip ci]

This commit is contained in:
Wing Lian
2025-06-27 10:37:53 -04:00
committed by GitHub
parent 29289a4de9
commit 24f2887e87

View File

@@ -75,13 +75,17 @@ def load_datasets(
num_examples = cli_args.debug_num_examples if cli_args else 1 num_examples = cli_args.debug_num_examples if cli_args else 1
text_only = cli_args.debug_text_only if cli_args else False text_only = cli_args.debug_text_only if cli_args else False
train_samples = sample_dataset(train_dataset, num_examples) try:
check_dataset_labels( train_samples = sample_dataset(train_dataset, num_examples)
train_samples, check_dataset_labels(
tokenizer, train_samples,
num_examples=num_examples, tokenizer,
text_only=text_only, num_examples=num_examples,
) text_only=text_only,
)
except AttributeError:
# can't sample iterable datasets
pass
LOG.info("printing prompters...") LOG.info("printing prompters...")
for prompter in prompters: for prompter in prompters: