fixes
This commit is contained in:
@@ -145,7 +145,13 @@ def _prepare_standard_dataset(
|
||||
return train_dataset, eval_dataset, -1, prompters
|
||||
|
||||
# Validate sample packing configuration for evaluation
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
# Skip validation for streaming eval datasets since they don't have a calculable length
|
||||
if (
|
||||
eval_dataset
|
||||
and cfg.sample_packing
|
||||
and cfg.eval_sample_packing is not False
|
||||
and not isinstance(eval_dataset, IterableDataset)
|
||||
):
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
if total_eval_steps == 0:
|
||||
raise ValueError(
|
||||
@@ -355,7 +361,7 @@ def _load_raw_datasets(
|
||||
cfg: DictDefault,
|
||||
datasets_configs: list,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
split: str,
|
||||
split: Literal["train", "test"],
|
||||
processor: ProcessorMixin | None = None,
|
||||
) -> tuple[Dataset, list[Prompter | None]]:
|
||||
"""Load, process, merge, and save raw datasets."""
|
||||
@@ -406,13 +412,14 @@ def _load_and_process_single_dataset(
|
||||
dataset_config: DictDefault,
|
||||
cfg: DictDefault,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
split: str,
|
||||
split: Literal["train", "test"],
|
||||
seed: int,
|
||||
processor: ProcessorMixin | None = None,
|
||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||
"""Load and process a single dataset based on the passed config."""
|
||||
use_streaming_for_split = _is_streaming_enabled_for_split(cfg, split)
|
||||
dataset = load_dataset_with_config(
|
||||
dataset_config, cfg.hf_use_auth_token, cfg.streaming
|
||||
dataset_config, cfg.hf_use_auth_token, use_streaming_for_split
|
||||
)
|
||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||
|
||||
|
||||
@@ -1127,6 +1127,30 @@ class PretrainingValidationMixin:
|
||||
data["accelerator_config"]["dispatch_batches"] = False
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_streaming_split_batches_accelerate(cls, data):
|
||||
# Check if either training or eval uses streaming
|
||||
streaming = data.get("streaming", False)
|
||||
eval_streaming = data.get("eval_streaming")
|
||||
if eval_streaming is None:
|
||||
eval_streaming = streaming
|
||||
|
||||
# If either training or eval uses streaming, configure accelerator
|
||||
if streaming or eval_streaming:
|
||||
accelerator_config = data.get("accelerator_config", {})
|
||||
if not accelerator_config:
|
||||
data["accelerator_config"] = {
|
||||
"split_batches": False,
|
||||
"dispatch_batches": False,
|
||||
}
|
||||
else:
|
||||
if accelerator_config.get("split_batches") is None:
|
||||
data["accelerator_config"]["split_batches"] = False
|
||||
if accelerator_config.get("dispatch_batches") is None:
|
||||
data["accelerator_config"]["dispatch_batches"] = False
|
||||
return data
|
||||
|
||||
|
||||
class ModelCompatibilityValidationMixin:
|
||||
"""Validation methods for specific model compatibility."""
|
||||
|
||||
@@ -79,7 +79,7 @@ class TestStreamingDatasets:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
25.0, # Loss should be reasonable for a smoke test (higher threshold for streaming)
|
||||
2.5, # Loss should be reasonable for a smoke test (higher threshold for streaming)
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@@ -151,13 +151,13 @@ class TestStreamingDatasets:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
25.0,
|
||||
2.5,
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs",
|
||||
"eval/eval_loss",
|
||||
25.0,
|
||||
2.5,
|
||||
"Eval Loss (%s) is too high",
|
||||
)
|
||||
|
||||
@@ -256,6 +256,6 @@ class TestStreamingDatasets:
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs",
|
||||
"train/train_loss",
|
||||
25.0,
|
||||
2.5,
|
||||
"Train Loss (%s) is too high",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user