This commit is contained in:
Dan Saunders
2025-08-20 17:41:33 +00:00
parent 1b7b67d06e
commit 7eba3795fe
3 changed files with 39 additions and 8 deletions

View File

@@ -145,7 +145,13 @@ def _prepare_standard_dataset(
return train_dataset, eval_dataset, -1, prompters
# Validate sample packing configuration for evaluation
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
# Skip validation for streaming eval datasets since they don't have a calculable length
if (
eval_dataset
and cfg.sample_packing
and cfg.eval_sample_packing is not False
and not isinstance(eval_dataset, IterableDataset)
):
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
@@ -355,7 +361,7 @@ def _load_raw_datasets(
cfg: DictDefault,
datasets_configs: list,
tokenizer: PreTrainedTokenizer,
split: str,
split: Literal["train", "test"],
processor: ProcessorMixin | None = None,
) -> tuple[Dataset, list[Prompter | None]]:
"""Load, process, merge, and save raw datasets."""
@@ -406,13 +412,14 @@ def _load_and_process_single_dataset(
dataset_config: DictDefault,
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
split: str,
split: Literal["train", "test"],
seed: int,
processor: ProcessorMixin | None = None,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
use_streaming_for_split = _is_streaming_enabled_for_split(cfg, split)
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, cfg.streaming
dataset_config, cfg.hf_use_auth_token, use_streaming_for_split
)
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)

View File

@@ -1127,6 +1127,30 @@ class PretrainingValidationMixin:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_streaming_split_batches_accelerate(cls, data):
    """Disable accelerator batch splitting/dispatching when streaming is active.

    Streaming (iterable) datasets have no calculable length, so the
    accelerator must not split or dispatch batches. Eval streaming falls
    back to the training ``streaming`` flag when it is left unset. Only
    keys the user has not explicitly configured are overridden.
    """
    train_streaming = data.get("streaming", False)
    val_streaming = data.get("eval_streaming")
    # Unset eval streaming inherits the training-side setting.
    effective_eval = train_streaming if val_streaming is None else val_streaming

    if not (train_streaming or effective_eval):
        return data

    existing_config = data.get("accelerator_config", {})
    if existing_config:
        # Respect explicit user choices; only fill in missing keys.
        for key in ("split_batches", "dispatch_batches"):
            if existing_config.get(key) is None:
                data["accelerator_config"][key] = False
    else:
        data["accelerator_config"] = {
            "split_batches": False,
            "dispatch_batches": False,
        }
    return data
class ModelCompatibilityValidationMixin:
"""Validation methods for specific model compatibility."""

View File

@@ -79,7 +79,7 @@ class TestStreamingDatasets:
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
25.0, # Loss should be reasonable for a smoke test (higher threshold for streaming)
2.5, # Loss should be reasonable for a smoke test (higher threshold for streaming)
"Train Loss (%s) is too high",
)
@@ -151,13 +151,13 @@ class TestStreamingDatasets:
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
25.0,
2.5,
"Train Loss (%s) is too high",
)
check_tensorboard(
temp_dir + "/runs",
"eval/eval_loss",
25.0,
2.5,
"Eval Loss (%s) is too high",
)
@@ -256,6 +256,6 @@ class TestStreamingDatasets:
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
25.0,
2.5,
"Train Loss (%s) is too high",
)