diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index aeb55a73b..d8a7174fa 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -145,7 +145,13 @@ def _prepare_standard_dataset( return train_dataset, eval_dataset, -1, prompters # Validate sample packing configuration for evaluation - if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False: + # Skip validation for streaming eval datasets since they don't have a calculable length + if ( + eval_dataset + and cfg.sample_packing + and cfg.eval_sample_packing is not False + and not isinstance(eval_dataset, IterableDataset) + ): total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False) if total_eval_steps == 0: raise ValueError( @@ -355,7 +361,7 @@ def _load_raw_datasets( cfg: DictDefault, datasets_configs: list, tokenizer: PreTrainedTokenizer, - split: str, + split: Literal["train", "test"], processor: ProcessorMixin | None = None, ) -> tuple[Dataset, list[Prompter | None]]: """Load, process, merge, and save raw datasets.""" @@ -406,13 +412,14 @@ def _load_and_process_single_dataset( dataset_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer, - split: str, + split: Literal["train", "test"], seed: int, processor: ProcessorMixin | None = None, ) -> tuple[Dataset | IterableDataset, Prompter | None]: """Load and process a single dataset based on the passed config.""" + use_streaming_for_split = _is_streaming_enabled_for_split(cfg, split) dataset = load_dataset_with_config( - dataset_config, cfg.hf_use_auth_token, cfg.streaming + dataset_config, cfg.hf_use_auth_token, use_streaming_for_split ) d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type) diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 7a5296ec8..6c4fa0517 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1127,6 +1127,30 @@ 
class PretrainingValidationMixin: data["accelerator_config"]["dispatch_batches"] = False return data + @model_validator(mode="before") + @classmethod + def check_streaming_split_batches_accelerate(cls, data): + # Check if either training or eval uses streaming + streaming = data.get("streaming", False) + eval_streaming = data.get("eval_streaming") + if eval_streaming is None: + eval_streaming = streaming + + # If either training or eval uses streaming, configure accelerator + if streaming or eval_streaming: + accelerator_config = data.get("accelerator_config", {}) + if not accelerator_config: + data["accelerator_config"] = { + "split_batches": False, + "dispatch_batches": False, + } + else: + if accelerator_config.get("split_batches") is None: + data["accelerator_config"]["split_batches"] = False + if accelerator_config.get("dispatch_batches") is None: + data["accelerator_config"]["dispatch_batches"] = False + return data + class ModelCompatibilityValidationMixin: """Validation methods for specific model compatibility.""" diff --git a/tests/e2e/test_streaming.py b/tests/e2e/test_streaming.py index 6ab077440..e0fec5876 100644 --- a/tests/e2e/test_streaming.py +++ b/tests/e2e/test_streaming.py @@ -79,7 +79,7 @@ class TestStreamingDatasets: check_tensorboard( temp_dir + "/runs", "train/train_loss", - 25.0, # Loss should be reasonable for a smoke test (higher threshold for streaming) + 2.5, # Loss should be reasonable for a smoke test (higher threshold for streaming) "Train Loss (%s) is too high", ) @@ -151,13 +151,13 @@ class TestStreamingDatasets: check_tensorboard( temp_dir + "/runs", "train/train_loss", - 25.0, + 2.5, "Train Loss (%s) is too high", ) check_tensorboard( temp_dir + "/runs", "eval/eval_loss", - 25.0, + 2.5, "Eval Loss (%s) is too high", ) @@ -256,6 +256,6 @@ class TestStreamingDatasets: check_tensorboard( temp_dir + "/runs", "train/train_loss", - 25.0, + 2.5, "Train Loss (%s) is too high", )