fixes
This commit is contained in:
@@ -145,7 +145,13 @@ def _prepare_standard_dataset(
|
|||||||
return train_dataset, eval_dataset, -1, prompters
|
return train_dataset, eval_dataset, -1, prompters
|
||||||
|
|
||||||
# Validate sample packing configuration for evaluation
|
# Validate sample packing configuration for evaluation
|
||||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
# Skip validation for streaming eval datasets since they don't have a calculable length
|
||||||
|
if (
|
||||||
|
eval_dataset
|
||||||
|
and cfg.sample_packing
|
||||||
|
and cfg.eval_sample_packing is not False
|
||||||
|
and not isinstance(eval_dataset, IterableDataset)
|
||||||
|
):
|
||||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||||
if total_eval_steps == 0:
|
if total_eval_steps == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -355,7 +361,7 @@ def _load_raw_datasets(
|
|||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
datasets_configs: list,
|
datasets_configs: list,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
split: str,
|
split: Literal["train", "test"],
|
||||||
processor: ProcessorMixin | None = None,
|
processor: ProcessorMixin | None = None,
|
||||||
) -> tuple[Dataset, list[Prompter | None]]:
|
) -> tuple[Dataset, list[Prompter | None]]:
|
||||||
"""Load, process, merge, and save raw datasets."""
|
"""Load, process, merge, and save raw datasets."""
|
||||||
@@ -406,13 +412,14 @@ def _load_and_process_single_dataset(
|
|||||||
dataset_config: DictDefault,
|
dataset_config: DictDefault,
|
||||||
cfg: DictDefault,
|
cfg: DictDefault,
|
||||||
tokenizer: PreTrainedTokenizer,
|
tokenizer: PreTrainedTokenizer,
|
||||||
split: str,
|
split: Literal["train", "test"],
|
||||||
seed: int,
|
seed: int,
|
||||||
processor: ProcessorMixin | None = None,
|
processor: ProcessorMixin | None = None,
|
||||||
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
) -> tuple[Dataset | IterableDataset, Prompter | None]:
|
||||||
"""Load and process a single dataset based on the passed config."""
|
"""Load and process a single dataset based on the passed config."""
|
||||||
|
use_streaming_for_split = _is_streaming_enabled_for_split(cfg, split)
|
||||||
dataset = load_dataset_with_config(
|
dataset = load_dataset_with_config(
|
||||||
dataset_config, cfg.hf_use_auth_token, cfg.streaming
|
dataset_config, cfg.hf_use_auth_token, use_streaming_for_split
|
||||||
)
|
)
|
||||||
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
|
||||||
|
|
||||||
|
|||||||
@@ -1127,6 +1127,30 @@ class PretrainingValidationMixin:
|
|||||||
data["accelerator_config"]["dispatch_batches"] = False
|
data["accelerator_config"]["dispatch_batches"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_streaming_split_batches_accelerate(cls, data):
|
||||||
|
# Check if either training or eval uses streaming
|
||||||
|
streaming = data.get("streaming", False)
|
||||||
|
eval_streaming = data.get("eval_streaming")
|
||||||
|
if eval_streaming is None:
|
||||||
|
eval_streaming = streaming
|
||||||
|
|
||||||
|
# If either training or eval uses streaming, configure accelerator
|
||||||
|
if streaming or eval_streaming:
|
||||||
|
accelerator_config = data.get("accelerator_config", {})
|
||||||
|
if not accelerator_config:
|
||||||
|
data["accelerator_config"] = {
|
||||||
|
"split_batches": False,
|
||||||
|
"dispatch_batches": False,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
if accelerator_config.get("split_batches") is None:
|
||||||
|
data["accelerator_config"]["split_batches"] = False
|
||||||
|
if accelerator_config.get("dispatch_batches") is None:
|
||||||
|
data["accelerator_config"]["dispatch_batches"] = False
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ModelCompatibilityValidationMixin:
|
class ModelCompatibilityValidationMixin:
|
||||||
"""Validation methods for specific model compatibility."""
|
"""Validation methods for specific model compatibility."""
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ class TestStreamingDatasets:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/train_loss",
|
||||||
25.0, # Loss should be reasonable for a smoke test (higher threshold for streaming)
|
2.5, # Loss should be reasonable for a smoke test (higher threshold for streaming)
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -151,13 +151,13 @@ class TestStreamingDatasets:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/train_loss",
|
||||||
25.0,
|
2.5,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"eval/eval_loss",
|
"eval/eval_loss",
|
||||||
25.0,
|
2.5,
|
||||||
"Eval Loss (%s) is too high",
|
"Eval Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -256,6 +256,6 @@ class TestStreamingDatasets:
|
|||||||
check_tensorboard(
|
check_tensorboard(
|
||||||
temp_dir + "/runs",
|
temp_dir + "/runs",
|
||||||
"train/train_loss",
|
"train/train_loss",
|
||||||
25.0,
|
2.5,
|
||||||
"Train Loss (%s) is too high",
|
"Train Loss (%s) is too high",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user