diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index cb467c8f6..a94270c92 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1456,21 +1456,6 @@ class StreamingValidationMixin: return self - @model_validator(mode="after") - def check_streaming_skip_prepare_dataset(self): - """Ensure skip_prepare_dataset is set for streaming datasets.""" - # Check if streaming is enabled for training datasets - if self._is_streaming_enabled(): - skip_prepare = getattr(self, "skip_prepare_dataset", None) - if skip_prepare is False: - LOG.warning( - "skip_prepare_dataset=False is not compatible with streaming " - "datasets. Setting skip_prepare_dataset=True." - ) - self.skip_prepare_dataset = True - - return self - @model_validator(mode="after") def check_dataset_mixing_weights(self): """Validate dataset mixing weights configuration.""" diff --git a/tests/e2e/test_streaming.py b/tests/e2e/test_streaming.py index c10952c6b..1425fc859 100644 --- a/tests/e2e/test_streaming.py +++ b/tests/e2e/test_streaming.py @@ -85,84 +85,6 @@ class TestStreamingDatasets: "Train Loss (%s) is too high", ) - def test_streaming_eval_specific_mixing(self, temp_dir): - """Test eval-specific mixing strategy override""" - - cfg = DictDefault( - { - "base_model": "HuggingFaceTB/SmolLM2-135M", - "flash_attention": True, - "sequence_len": 512, - "sample_packing": False, - "dataset_processes": 1, - "special_tokens": { - "pad_token": "<|endoftext|>", - }, - "datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - }, - { - "path": "tatsu-lab/alpaca", - "type": "alpaca", - }, - ], - "test_datasets": [ - { - "path": "mhenrichsen/alpaca_2k_test", - "type": "alpaca", - "split": "train", - }, - { - "path": "tatsu-lab/alpaca", - "type": "alpaca", - "split": "train", - }, - ], - # Streaming config - "streaming": True, - "eval_streaming": True, - "max_steps": 3, - # Different mixing for train vs eval - "dataset_mixing_strategy": "round_robin", - "eval_dataset_mixing_strategy": "weighted", - "eval_mixing_weights": [0.6, 0.4], - "micro_batch_size": 1, - "gradient_accumulation_steps": 1, - "output_dir": temp_dir, - "learning_rate": 0.00001, - "optimizer": "adamw_torch_fused", - "lr_scheduler": "cosine", - "save_safetensors": True, - "bf16": "auto", - "use_tensorboard": True, - "save_first_step": False, - "eval_steps": 3, # Eval at the end - } - ) - - cfg = validate_config(cfg) - normalize_config(cfg) - dataset_meta = load_datasets(cfg=cfg) - - train(cfg=cfg, dataset_meta=dataset_meta) - check_model_output_exists(temp_dir, cfg) - - # Check both train and eval losses - check_tensorboard( - temp_dir + "/runs", - "train/train_loss", - 2.5, - "Train Loss (%s) is too high", - ) - check_tensorboard( - temp_dir + "/runs", - "eval/eval_loss", - 2.5, - "Eval Loss (%s) is too high", - ) - def test_streaming_validation_error(self, temp_dir): """Test that pydantic validation catches invalid streaming configs"""