From 2176962231f0ce1d5e05e020e98b71dda088394f Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 20 Aug 2025 05:17:05 +0000 Subject: [PATCH] separate out train and eval dataset streaming --- src/axolotl/utils/data/sft.py | 52 +++++++++- src/axolotl/utils/data/shared.py | 5 +- src/axolotl/utils/schemas/config.py | 20 +++- src/axolotl/utils/schemas/validation.py | 122 +++++++++++++----------- 4 files changed, 135 insertions(+), 64 deletions(-) diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index 4c51aa2d1..8fb4ea63d 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -44,6 +44,48 @@ from axolotl.utils.trainer import ( LOG = get_logger(__name__) +def _is_streaming_enabled_for_split( + cfg: DictDefault, split: Literal["train", "test"] +) -> bool: + """Check if streaming is enabled for a specific split.""" + if split == "test": + # For eval datasets, check eval_streaming first, then fall back to streaming + eval_streaming = cfg.get("eval_streaming") + if eval_streaming is not None: + return eval_streaming + + # Fall back to main streaming setting + streaming = cfg.get("streaming") + if streaming is True: + return True + + # Check if pretraining dataset exists (defaults to streaming) + has_pretraining = cfg.get("pretraining_dataset") is not None + streaming_default_for_pretraining = has_pretraining and streaming is None + + return streaming_default_for_pretraining + + +def _get_streaming_config_for_split( + cfg: DictDefault, split: Literal["train", "test"] +) -> DictDefault: + """Get a modified config object with split-specific streaming settings.""" + if split != "test": + return cfg + + # Override with eval-specific configs if they exist + streaming_cfg = DictDefault(cfg) + eval_strategy = cfg.get("eval_streaming_dataset_mixing_strategy") + eval_weights = cfg.get("eval_streaming_mixing_weights") + + if eval_strategy is not None: + streaming_cfg["streaming_dataset_mixing_strategy"] = eval_strategy + if eval_weights is not None: + streaming_cfg["streaming_mixing_weights"] = eval_weights + + return streaming_cfg + + @retry_on_request_exceptions(max_retries=3, delay=5) def prepare_datasets( cfg: DictDefault, @@ -267,10 +309,14 @@ def _load_tokenized_prepared_datasets( datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets prompters: list[Prompter | None] = [] - # For streaming datasets, skip caching and load raw datasets directly - if cfg.streaming: + # Check if streaming is enabled for this split + use_streaming = _is_streaming_enabled_for_split(cfg, split) + + if use_streaming: + # For streaming datasets, skip caching and load raw datasets directly + streaming_cfg = _get_streaming_config_for_split(cfg, split) dataset, prompters = _load_raw_datasets( - cfg, + streaming_cfg, datasets_configs, tokenizer, split, diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index 3cbb46674..9da658675 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -550,10 +550,7 @@ def merge_datasets( return ds return ds.shuffle(seed=cfg.seed) - # Check if we have any IterableDatasets - has_iterable = any(isinstance(ds, IterableDataset) for ds in datasets) - - if has_iterable: + if any(isinstance(ds, IterableDataset) for ds in datasets): LOG.info("Merging streaming datasets...") merged_dataset = _merge_streaming_datasets(datasets, cfg) else: diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 4b1b7fbf3..c335e6c73 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -935,7 +935,13 @@ class AxolotlInputConfig( streaming: bool | None = Field( default=None, json_schema_extra={ - "description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False." + "description": "Whether to use streaming datasets (IterableDataset) for training datasets. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False." + }, + ) + eval_streaming: bool | None = Field( + default=None, + json_schema_extra={ + "description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory." }, ) streaming_dataset_mixing_strategy: str | None = Field( @@ -950,6 +956,18 @@ class AxolotlInputConfig( "description": "Weights for weighted mixing strategy when using multiple streaming datasets. Must sum to 1.0 and have same length as datasets list. Only used when streaming_dataset_mixing_strategy='weighted'." }, ) + eval_streaming_dataset_mixing_strategy: str | None = Field( + default=None, + json_schema_extra={ + "description": "Strategy for mixing multiple streaming evaluation datasets. If not set, falls back to streaming_dataset_mixing_strategy. Options: 'round_robin', 'weighted', 'random'." + }, + ) + eval_streaming_mixing_weights: list[float] | None = Field( + default=None, + json_schema_extra={ + "description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list." + }, + ) # INTERNALS - document for now, generally not set externally is_preprocess: bool | None = None diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 42993c55c..5ec85709a 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1388,20 +1388,29 @@ class GRPOVllmValidationMixin: class StreamingValidationMixin: """Validation methods related to streaming datasets.""" - @model_validator(mode="after") - def check_streaming_requires_max_steps(self): - """Ensure max_steps is set when using streaming datasets.""" - # Check if streaming is explicitly enabled - streaming_enabled = getattr(self, "streaming", None) is True + def _is_streaming_enabled(self, context: str = "train") -> bool: + """Check if streaming is enabled for a given context (train or eval).""" + if context == "eval": + eval_streaming = getattr(self, "eval_streaming", None) + if eval_streaming is not None: + return eval_streaming + + # Fall back to main streaming setting + streaming = getattr(self, "streaming", None) + if streaming is True: + return True # Check if pretraining dataset exists (defaults to streaming) has_pretraining = getattr(self, "pretraining_dataset", None) is not None - streaming_default_for_pretraining = ( - has_pretraining and getattr(self, "streaming", None) is None - ) + streaming_default_for_pretraining = has_pretraining and streaming is None - # If streaming is enabled (explicitly or by default for pretraining) - if streaming_enabled or streaming_default_for_pretraining: + return streaming_default_for_pretraining + + @model_validator(mode="after") + def check_streaming_requires_max_steps(self): + """Ensure max_steps is set when using streaming datasets.""" + # Check if streaming is enabled for training datasets + if self._is_streaming_enabled("train"): max_steps = getattr(self, "max_steps", None) if not max_steps: raise ValueError("max_steps must be set when using streaming datasets") @@ -1411,17 +1420,8 @@ class StreamingValidationMixin: @model_validator(mode="after") def check_streaming_validation_splits_conflict(self): """Ensure validation splits are not used with streaming datasets.""" - # Check if streaming is explicitly enabled - streaming_enabled = getattr(self, "streaming", None) is True - - # Check if pretraining dataset exists (defaults to streaming) - has_pretraining = getattr(self, "pretraining_dataset", None) is not None - streaming_default_for_pretraining = ( - has_pretraining and getattr(self, "streaming", None) is None - ) - - # If streaming is enabled (explicitly or by default for pretraining) - if streaming_enabled or streaming_default_for_pretraining: + # Check if streaming is enabled for training datasets + if self._is_streaming_enabled("train"): val_set_size = getattr(self, "val_set_size", 0.0) if val_set_size and val_set_size > 0: raise ValueError( @@ -1433,17 +1433,8 @@ class StreamingValidationMixin: @model_validator(mode="after") def check_streaming_preprocessing_conflict(self): """Ensure preprocessing is not enabled with streaming datasets.""" - # Check if streaming is explicitly enabled - streaming_enabled = getattr(self, "streaming", None) is True - - # Check if pretraining dataset exists (defaults to streaming) - has_pretraining = getattr(self, "pretraining_dataset", None) is not None - streaming_default_for_pretraining = ( - has_pretraining and getattr(self, "streaming", None) is None - ) - - # If streaming is enabled (explicitly or by default for pretraining) - if streaming_enabled or streaming_default_for_pretraining: + # Check if streaming is enabled for training or eval datasets + if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"): if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1": raise ValueError("preprocess is not supported for streaming datasets") @@ -1452,17 +1443,8 @@ class StreamingValidationMixin: @model_validator(mode="after") def check_streaming_skip_prepare_dataset(self): """Ensure skip_prepare_dataset is set for streaming datasets.""" - # Check if streaming is explicitly enabled - streaming_enabled = getattr(self, "streaming", None) is True - - # Check if pretraining dataset exists (defaults to streaming) - has_pretraining = getattr(self, "pretraining_dataset", None) is not None - streaming_default_for_pretraining = ( - has_pretraining and getattr(self, "streaming", None) is None - ) - - # If streaming is enabled (explicitly or by default for pretraining) - if streaming_enabled or streaming_default_for_pretraining: + # Check if streaming is enabled for training or eval datasets + if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"): skip_prepare = getattr(self, "skip_prepare_dataset", None) if skip_prepare is False: LOG.warning( @@ -1476,45 +1458,73 @@ class StreamingValidationMixin: @model_validator(mode="after") def check_streaming_mixing_weights(self): """Validate streaming_mixing_weights configuration.""" + valid_strategies = ["round_robin", "weighted", "random"] + + # Check main strategy and weights strategy = getattr(self, "streaming_dataset_mixing_strategy", "round_robin") weights = getattr(self, "streaming_mixing_weights", None) + self._validate_streaming_strategy_and_weights( + strategy, + weights, + "streaming_dataset_mixing_strategy", + "streaming_mixing_weights", + valid_strategies, + ) - # Validate strategy values - valid_strategies = ["round_robin", "weighted", "random"] + # Check eval-specific strategy and weights + eval_strategy = getattr(self, "eval_streaming_dataset_mixing_strategy", None) + eval_weights = getattr(self, "eval_streaming_mixing_weights", None) + + if eval_strategy is not None: + self._validate_streaming_strategy_and_weights( + eval_strategy, + eval_weights, + "eval_streaming_dataset_mixing_strategy", + "eval_streaming_mixing_weights", + valid_strategies, + ) + elif eval_weights is not None: + LOG.warning( + "eval_streaming_mixing_weights provided but eval_streaming_dataset_mixing_strategy is not set. " + "Weights will be ignored unless eval_streaming_dataset_mixing_strategy='weighted'." + ) + + return self + + def _validate_streaming_strategy_and_weights( + self, strategy, weights, strategy_field, weights_field, valid_strategies + ): + """Helper method to validate strategy and weights pair.""" if strategy not in valid_strategies: raise ValueError( - f"streaming_dataset_mixing_strategy must be one of {valid_strategies}, " + f"{strategy_field} must be one of {valid_strategies}, " f"got '{strategy}'" ) if strategy == "weighted": if weights is None: raise ValueError( - "streaming_mixing_weights must be provided when " - "streaming_dataset_mixing_strategy='weighted'" + f"{weights_field} must be provided when " + f"{strategy_field}='weighted'" ) if not isinstance(weights, list) or not all( isinstance(w, (int, float)) for w in weights ): - raise ValueError("streaming_mixing_weights must be a list of numbers") + raise ValueError(f"{weights_field} must be a list of numbers") if any(w < 0 for w in weights): - raise ValueError("streaming_mixing_weights must be non-negative") + raise ValueError(f"{weights_field} must be non-negative") if abs(sum(weights) - 1.0) > 1e-6: - raise ValueError( - f"streaming_mixing_weights must sum to 1.0, got {sum(weights)}" - ) + raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}") elif weights is not None and strategy != "weighted": LOG.warning( - f"streaming_mixing_weights provided but strategy is '{strategy}'. " + f"{weights_field} provided but {strategy_field} is '{strategy}'. " "Weights will be ignored." ) - return self - # pylint: disable=too-many-ancestors class ValidationMixin(