From 10335d5df99eabe13a53e434fcbe973debe12305 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 20 Aug 2025 04:44:07 +0000 Subject: [PATCH] add multidata strats --- src/axolotl/utils/data/shared.py | 30 +++++++++++++++++- src/axolotl/utils/data/utils.py | 12 +++---- src/axolotl/utils/schemas/config.py | 9 ------ src/axolotl/utils/schemas/validation.py | 42 +++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 17 deletions(-) diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index bde9e94e1..3cbb46674 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -555,7 +555,7 @@ def merge_datasets( if has_iterable: LOG.info("Merging streaming datasets...") - merged_dataset = interleave_datasets(datasets, seed=cfg.seed) + merged_dataset = _merge_streaming_datasets(datasets, cfg) else: # If enabled, shuffle each dataset independently before merging. # This allows curriculum learning strategies to be applied at the dataset level. @@ -581,3 +581,31 @@ def merge_datasets( LOG.debug("Not shuffling merged datasets.") return merged_dataset + + +def _merge_streaming_datasets( + datasets: list[Dataset | IterableDataset], cfg: DictDefault +) -> IterableDataset: + """Merge streaming datasets using the configured mixing strategy. + + Args: + datasets: List of datasets to merge (at least one must be IterableDataset). + cfg: Configuration object containing streaming mixing settings. + + Returns: + Merged IterableDataset. + """ + # Get mixing configuration + strategy = cfg.get("streaming_dataset_mixing_strategy", "round_robin") + weights = cfg.get("streaming_mixing_weights", None) + + LOG.info(f"Using streaming mixing strategy: {strategy}") + + if strategy == "round_robin": + return interleave_datasets(datasets, seed=cfg.seed) + if strategy == "weighted": + return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed) + + return interleave_datasets( + datasets, probabilities=[1.0 / len(datasets)] * len(datasets), seed=cfg.seed + ) diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index d5fa54196..9285ea366 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -193,15 +193,13 @@ def handle_long_seq_in_dataset( if hasattr(dataset, "column_names") and dataset.column_names: if "input_ids" not in dataset.column_names: LOG.warning( - "Dataset does not contain 'input_ids' column. Skip drop long seq. This is " - "expected for reward modeling." + "Dataset does not contain 'input_ids' column. Skip drop long seq. This " + "is expected for reward modeling." ) return dataset - else: - # For IterableDataset, we can't check columns upfront, so skip for streaming - if isinstance(dataset, IterableDataset): - LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)") - return dataset + elif isinstance(dataset, IterableDataset): + LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)") + return dataset drop_long = functools.partial( drop_long_seq, diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index 882485af8..4b1b7fbf3 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -938,14 +938,12 @@ class AxolotlInputConfig( "description": "Whether to use streaming datasets (IterableDataset) for processing large datasets that don't fit in memory. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False." }, ) - streaming_dataset_mixing_strategy: str | None = Field( default="round_robin", json_schema_extra={ "description": "Strategy for mixing multiple streaming datasets: 'round_robin' (equal sampling), 'weighted' (use streaming_mixing_weights), or 'random' (random sampling with equal probability)." }, ) - streaming_mixing_weights: list[float] | None = Field( default=None, json_schema_extra={ @@ -953,13 +951,6 @@ class AxolotlInputConfig( }, ) - streaming_buffer_per_dataset: int | None = Field( - default=1000, - json_schema_extra={ - "description": "Buffer size per dataset when mixing multiple streaming datasets. Higher values may improve mixing quality but use more memory." - }, - ) - # INTERNALS - document for now, generally not set externally is_preprocess: bool | None = None diff --git a/src/axolotl/utils/schemas/validation.py b/src/axolotl/utils/schemas/validation.py index 64fca363c..42993c55c 100644 --- a/src/axolotl/utils/schemas/validation.py +++ b/src/axolotl/utils/schemas/validation.py @@ -1473,6 +1473,48 @@ class StreamingValidationMixin: return self + @model_validator(mode="after") + def check_streaming_mixing_weights(self): + """Validate streaming_mixing_weights configuration.""" + strategy = getattr(self, "streaming_dataset_mixing_strategy", "round_robin") + weights = getattr(self, "streaming_mixing_weights", None) + + # Validate strategy values + valid_strategies = ["round_robin", "weighted", "random"] + if strategy not in valid_strategies: + raise ValueError( + f"streaming_dataset_mixing_strategy must be one of {valid_strategies}, " + f"got '{strategy}'" + ) + + if strategy == "weighted": + if weights is None: + raise ValueError( + "streaming_mixing_weights must be provided when " + "streaming_dataset_mixing_strategy='weighted'" + ) + + if not isinstance(weights, list) or not all( + isinstance(w, (int, float)) for w in weights + ): + raise ValueError("streaming_mixing_weights must be a list of numbers") + + if any(w < 0 for w in weights): + raise ValueError("streaming_mixing_weights must be non-negative") + + if abs(sum(weights) - 1.0) > 1e-6: + raise ValueError( + f"streaming_mixing_weights must sum to 1.0, got {sum(weights)}" + ) + + elif weights is not None and strategy != "weighted": + LOG.warning( + f"streaming_mixing_weights provided but strategy is '{strategy}'. " + "Weights will be ignored." + ) + + return self + # pylint: disable=too-many-ancestors class ValidationMixin(