Compare commits

...

10 Commits

Author SHA1 Message Date
Dan Saunders
78a039e1be add depr warning for preprocess --iterable 2025-08-22 16:02:30 +00:00
Dan Saunders
69f356163e fix 2025-08-22 16:02:30 +00:00
Dan Saunders
53bbca2591 bugfix for sample packing 2025-08-22 16:02:30 +00:00
Dan Saunders
49bd6ece4a remove unused 2025-08-22 16:02:30 +00:00
Dan Saunders
42b38a718a remove eval streaming (not HF supported) 2025-08-22 16:02:30 +00:00
Dan Saunders
4121bcbc33 fix kd test 2025-08-22 16:02:30 +00:00
Dan Saunders
0caa24eab0 comments 2025-08-22 16:02:30 +00:00
Dan Saunders
68bb70bbae fix test 2025-08-22 16:02:30 +00:00
Dan Saunders
5d8d7ef327 lint 2025-08-22 16:02:30 +00:00
Dan Saunders
7836da9ed9 remove unuse 2025-08-22 16:02:30 +00:00
13 changed files with 138 additions and 422 deletions

View File

@@ -13,6 +13,16 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1) debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True) download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=False,
metadata={
"help": (
"[DEPRECATED] No longer supported. For streaming datasets, use "
"'axolotl train' and set 'streaming: true' in your YAML config, or "
"pass --streaming instead in the CLI."
)
},
)
@dataclass @dataclass

View File

@@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config() check_accelerate_default_config()
check_user_token() check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]: for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key): if cfg.get(key):
LOG.error( LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead." f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
) )
return return

View File

@@ -2,15 +2,12 @@
Module containing dataset functionality. Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Let's use the We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
concept of middlewares to wrap each dataset, for example: concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)])). datasets.
Let's check to ensure we don't truncate an item in the middle. We'll use the collators
later on to pad the datasets.
""" """
from typing import Any from typing import Any
import torch
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger from axolotl.utils.logging import get_logger
@@ -48,8 +45,6 @@ class TokenizedPromptDataset(Dataset):
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset: def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
"""Apply filtering and tokenization.""" """Apply filtering and tokenization."""
# For IterableDataset, we can't access features up front. Anyways, we don't care
# to remove unused columns from streaming datasets.
features = None features = None
if not isinstance(dataset, IterableDataset): if not isinstance(dataset, IterableDataset):
features = dataset.features.keys() features = dataset.features.keys()
@@ -98,139 +93,16 @@ def wrap_dataset_for_tokenized_prompt(
map_kwargs = {} map_kwargs = {}
if prompt_tokenizer.supports_batched: if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
# Map the dataset and remove original columns
# For IterableDataset, features might be None until first iteration
remove_columns = None
if dataset.features is not None:
remove_columns = list(dataset.features.keys())
return dataset.map( return dataset.map(
prompt_tokenizer.tokenize_prompt, prompt_tokenizer.tokenize_prompt,
remove_columns=remove_columns,
**map_kwargs, **map_kwargs,
) )
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs) return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO: this isn't the best since it can't interleave datasets.
# NOTE: this is only used in a test. Can it be deleted?
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -44,46 +44,17 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__) LOG = get_logger(__name__)
def _is_streaming_enabled_for_split( def _is_streaming_enabled(cfg: DictDefault) -> bool:
cfg: DictDefault, split: Literal["train", "test"]
) -> bool:
"""Check if streaming is enabled for a specific split.""" """Check if streaming is enabled for a specific split."""
if split == "test":
# For eval datasets, check eval_streaming first, then fall back to streaming
eval_streaming = cfg.get("eval_streaming")
if eval_streaming is not None:
return eval_streaming
# Fall back to main streaming setting
streaming = cfg.get("streaming") streaming = cfg.get("streaming")
if streaming is True: if streaming is True:
return True return True
# Check if pretraining dataset exists (defaults to streaming) # Check if pretraining dataset exists (defaults to streaming)
has_pretraining = cfg.get("pretraining_dataset") is not None has_pretraining = cfg.get("pretraining_dataset") is not None
streaming_default_for_pretraining = has_pretraining and streaming is None streaming = has_pretraining and streaming is None
return streaming_default_for_pretraining return streaming
def _get_streaming_config_for_split(
cfg: DictDefault, split: Literal["train", "test"]
) -> DictDefault:
"""Get a modified config object with split-specific streaming settings."""
if split != "test":
return cfg
# Override with eval-specific configs if they exist
streaming_cfg = DictDefault(cfg)
eval_strategy = cfg.get("eval_dataset_mixing_strategy")
eval_weights = cfg.get("eval_mixing_weights")
if eval_strategy is not None:
streaming_cfg["dataset_mixing_strategy"] = eval_strategy
if eval_weights is not None:
streaming_cfg["mixing_weights"] = eval_weights
return streaming_cfg
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
@@ -145,7 +116,6 @@ def _prepare_standard_dataset(
return train_dataset, eval_dataset, -1, prompters return train_dataset, eval_dataset, -1, prompters
# Validate sample packing configuration for evaluation # Validate sample packing configuration for evaluation
# Skip validation for streaming eval datasets since they don't have a calculable length
if ( if (
eval_dataset eval_dataset
and cfg.sample_packing and cfg.sample_packing
@@ -315,14 +285,14 @@ def _load_tokenized_prepared_datasets(
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
prompters: list[Prompter | None] = [] prompters: list[Prompter | None] = []
# Check if streaming is enabled for this split use_streaming = False
use_streaming = _is_streaming_enabled_for_split(cfg, split) if split == "train":
use_streaming = _is_streaming_enabled(cfg)
if use_streaming: if use_streaming:
# For streaming datasets, skip caching and load raw datasets directly # For streaming datasets, skip caching and load raw datasets directly
streaming_cfg = _get_streaming_config_for_split(cfg, split)
dataset, prompters = _load_raw_datasets( dataset, prompters = _load_raw_datasets(
streaming_cfg, cfg,
datasets_configs, datasets_configs,
tokenizer, tokenizer,
split, split,
@@ -417,9 +387,12 @@ def _load_and_process_single_dataset(
processor: ProcessorMixin | None = None, processor: ProcessorMixin | None = None,
) -> tuple[Dataset | IterableDataset, Prompter | None]: ) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config.""" """Load and process a single dataset based on the passed config."""
use_streaming_for_split = _is_streaming_enabled_for_split(cfg, split) use_streaming = False
if split == "train":
use_streaming = _is_streaming_enabled(cfg)
dataset = load_dataset_with_config( dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, use_streaming_for_split dataset_config, cfg.hf_use_auth_token, use_streaming
) )
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type) d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)

View File

@@ -593,7 +593,6 @@ def _merge_datasets_with_strategy(
LOG.info(f"Merging datasets with mixing strategy: {strategy}...") LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
if strategy == "concatenate": if strategy == "concatenate":
# Concatenate only works with non-iterable datasets
if not all(isinstance(ds, Dataset) for ds in datasets): if not all(isinstance(ds, Dataset) for ds in datasets):
raise ValueError( raise ValueError(
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', " "Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
@@ -605,7 +604,6 @@ def _merge_datasets_with_strategy(
if strategy == "weighted": if strategy == "weighted":
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed) return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
if strategy == "random": if strategy == "random":
# Random sampling with equal probability
equal_weights = [1.0 / len(datasets)] * len(datasets) equal_weights = [1.0 / len(datasets)] * len(datasets)
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed) return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
raise ValueError(f"Unknown dataset mixing strategy: {strategy}") raise ValueError(f"Unknown dataset mixing strategy: {strategy}")

View File

@@ -100,6 +100,10 @@ def get_dataset_wrapper(
dataset_config, tokenizer, cfg, dataset, dataset_kwargs dataset_config, tokenizer, cfg, dataset, dataset_kwargs
) )
# Skip preparation if configured
if cfg.skip_prepare_dataset:
return dataset, None
# Bradley-Terry dataset # Bradley-Terry dataset
if dataset_config.type.startswith("bradley_terry"): if dataset_config.type.startswith("bradley_terry"):
return _handle_bradley_terry_dataset( return _handle_bradley_terry_dataset(

View File

@@ -938,12 +938,6 @@ class AxolotlInputConfig(
"description": "Whether to use streaming datasets (IterableDataset) for training datasets. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False." "description": "Whether to use streaming datasets (IterableDataset) for training datasets. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
}, },
) )
eval_streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use streaming datasets for evaluation datasets. If not set, falls back to the 'streaming' setting. Useful for streaming large training data while keeping smaller eval datasets in memory."
},
)
dataset_mixing_strategy: str | None = Field( dataset_mixing_strategy: str | None = Field(
default="round_robin", default="round_robin",
json_schema_extra={ json_schema_extra={
@@ -956,18 +950,6 @@ class AxolotlInputConfig(
"description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'." "description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'."
}, },
) )
eval_dataset_mixing_strategy: str | None = Field(
default=None,
json_schema_extra={
"description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'concatenate', 'round_robin', 'weighted', 'random'."
},
)
eval_mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy for evaluation datasets. Must sum to 1.0 and have same length as evaluation datasets list."
},
)
# INTERNALS - document for now, generally not set externally # INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None is_preprocess: bool | None = None

View File

@@ -1130,14 +1130,11 @@ class PretrainingValidationMixin:
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_streaming_split_batches_accelerate(cls, data): def check_streaming_split_batches_accelerate(cls, data):
# Check if either training or eval uses streaming # Check if streaming is enabled for training
streaming = data.get("streaming", False) streaming = data.get("streaming", False)
eval_streaming = data.get("eval_streaming")
if eval_streaming is None:
eval_streaming = streaming
# If either training or eval uses streaming, configure accelerator # If streaming is enabled, configure accelerator
if streaming or eval_streaming: if streaming:
accelerator_config = data.get("accelerator_config", {}) accelerator_config = data.get("accelerator_config", {})
if not accelerator_config: if not accelerator_config:
data["accelerator_config"] = { data["accelerator_config"] = {
@@ -1412,13 +1409,8 @@ class GRPOVllmValidationMixin:
class StreamingValidationMixin: class StreamingValidationMixin:
"""Validation methods related to streaming datasets.""" """Validation methods related to streaming datasets."""
def _is_streaming_enabled(self, context: str = "train") -> bool: def _is_streaming_enabled(self) -> bool:
"""Check if streaming is enabled for a given context (train or eval).""" """Check if streaming is enabled."""
if context == "eval":
eval_streaming = getattr(self, "eval_streaming", None)
if eval_streaming is not None:
return eval_streaming
# Fall back to main streaming setting # Fall back to main streaming setting
streaming = getattr(self, "streaming", None) streaming = getattr(self, "streaming", None)
if streaming is True: if streaming is True:
@@ -1426,15 +1418,15 @@ class StreamingValidationMixin:
# Check if pretraining dataset exists (defaults to streaming) # Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming_default_for_pretraining = has_pretraining and streaming is None streaming = has_pretraining and streaming is None
return streaming_default_for_pretraining return streaming
@model_validator(mode="after") @model_validator(mode="after")
def check_streaming_requires_max_steps(self): def check_streaming_requires_max_steps(self):
"""Ensure max_steps is set when using streaming datasets.""" """Ensure max_steps is set when using streaming datasets."""
# Check if streaming is enabled for training datasets # Check if streaming is enabled for training datasets
if self._is_streaming_enabled("train"): if self._is_streaming_enabled():
max_steps = getattr(self, "max_steps", None) max_steps = getattr(self, "max_steps", None)
if not max_steps: if not max_steps:
raise ValueError("max_steps must be set when using streaming datasets") raise ValueError("max_steps must be set when using streaming datasets")
@@ -1445,11 +1437,12 @@ class StreamingValidationMixin:
def check_streaming_validation_splits_conflict(self): def check_streaming_validation_splits_conflict(self):
"""Ensure validation splits are not used with streaming datasets.""" """Ensure validation splits are not used with streaming datasets."""
# Check if streaming is enabled for training datasets # Check if streaming is enabled for training datasets
if self._is_streaming_enabled("train"): if self._is_streaming_enabled():
val_set_size = getattr(self, "val_set_size", 0.0) val_set_size = getattr(self, "val_set_size", 0.0)
if val_set_size and val_set_size > 0: if val_set_size and val_set_size > 0:
raise ValueError( raise ValueError(
"Validation splits not supported for streaming datasets, skipping" "Validation splits not supported for streaming datasets, please "
"use test_datasets: ... instead"
) )
return self return self
@@ -1457,28 +1450,13 @@ class StreamingValidationMixin:
@model_validator(mode="after") @model_validator(mode="after")
def check_streaming_preprocessing_conflict(self): def check_streaming_preprocessing_conflict(self):
"""Ensure preprocessing is not enabled with streaming datasets.""" """Ensure preprocessing is not enabled with streaming datasets."""
# Check if streaming is enabled for training or eval datasets # Check if streaming is enabled for training datasets
if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"): if self._is_streaming_enabled():
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1": if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
raise ValueError("preprocess is not supported for streaming datasets") raise ValueError("preprocess is not supported for streaming datasets")
return self return self
@model_validator(mode="after")
def check_streaming_skip_prepare_dataset(self):
"""Ensure skip_prepare_dataset is set for streaming datasets."""
# Check if streaming is enabled for training or eval datasets
if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"):
skip_prepare = getattr(self, "skip_prepare_dataset", None)
if skip_prepare is False:
LOG.warning(
"skip_prepare_dataset=False is not compatible with streaming "
"datasets. Setting skip_prepare_dataset=True."
)
self.skip_prepare_dataset = True
return self
@model_validator(mode="after") @model_validator(mode="after")
def check_dataset_mixing_weights(self): def check_dataset_mixing_weights(self):
"""Validate dataset mixing weights configuration.""" """Validate dataset mixing weights configuration."""
@@ -1486,7 +1464,6 @@ class StreamingValidationMixin:
# Get datasets to validate length against # Get datasets to validate length against
datasets = getattr(self, "datasets", None) datasets = getattr(self, "datasets", None)
test_datasets = getattr(self, "test_datasets", None)
# Check main strategy and weights # Check main strategy and weights
strategy = getattr(self, "dataset_mixing_strategy", "concatenate") strategy = getattr(self, "dataset_mixing_strategy", "concatenate")
@@ -1502,26 +1479,6 @@ class StreamingValidationMixin:
dataset_count, dataset_count,
) )
# Check eval-specific strategy and weights
eval_strategy = getattr(self, "eval_dataset_mixing_strategy", None)
eval_weights = getattr(self, "eval_mixing_weights", None)
if eval_strategy is not None:
eval_dataset_count = len(test_datasets) if test_datasets else dataset_count
self._validate_dataset_strategy_and_weights(
eval_strategy,
eval_weights,
"eval_dataset_mixing_strategy",
"eval_mixing_weights",
valid_strategies,
eval_dataset_count,
)
elif eval_weights is not None:
LOG.warning(
"eval_mixing_weights provided but eval_dataset_mixing_strategy is not set. "
"Weights will be ignored unless eval_dataset_mixing_strategy='weighted'."
)
return self return self
def _validate_dataset_strategy_and_weights( def _validate_dataset_strategy_and_weights(

View File

@@ -10,7 +10,6 @@ from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
import torch.cuda
from datasets import IterableDataset, disable_caching, enable_caching from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
@@ -23,6 +22,65 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__) LOG = get_logger(__name__)
def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
"""
Create a filtered IterableDataset that works around a HuggingFace datasets
limitation.
"""
def filtered_generator():
"""Generator that yields only samples that pass the filter function."""
if batched:
batch = []
batch_size = 1000 # Process in batches of 1000
for sample in dataset:
batch.append(sample)
if len(batch) >= batch_size:
# Create a batch dict from list of samples
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
# Apply filter function to batch
keep_mask = filter_fn(batch_dict)
# Yield samples that should be kept
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
batch = []
# Process remaining samples in batch
if batch:
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
keep_mask = filter_fn(batch_dict)
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
else:
# For non-batched filtering, apply filter to each sample individually
for sample in dataset:
if filter_fn(sample):
yield sample
# Create new IterableDataset from the filtered generator
filtered_dataset = IterableDataset.from_generator(filtered_generator)
# Preserve the original features if they exist
# pylint:disable=protected-access
if hasattr(dataset, "_info") and dataset._info.features is not None:
filtered_dataset._info.features = dataset._info.features
return filtered_dataset
@torch.jit.script @torch.jit.script
def weighted_cross_entropy( def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
@@ -282,12 +340,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs = {} drop_long_kwargs = {}
if filter_map_kwargs: if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens" drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
train_dataset = train_dataset.filter(
drop_no_trainable_tokens, # For IterableDatasets, always use custom filtering to avoid features issues
batched=True, if isinstance(train_dataset, IterableDataset):
**filter_map_kwargs, # IterableDatasets often have None features after transformations,
**drop_long_kwargs, # so we use our custom filter implementation that doesn't rely on features
) train_dataset = _create_filtered_iterable_dataset(
train_dataset, drop_no_trainable_tokens, batched=True
)
else:
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len: if prior_len:
dropped = prior_len - len(train_dataset) dropped = prior_len - len(train_dataset)
if dropped: if dropped:
@@ -472,7 +539,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
) )
data_loader = DataLoader( data_loader = DataLoader(
train_dataset.remove_columns(["length"]), train_dataset,
batch_sampler=sampler, batch_sampler=sampler,
) )
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size

View File

@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
"liger_rms_norm": True, "liger_rms_norm": True,
"liger_glu_activation": True, "liger_glu_activation": True,
"torch_compile": True, "torch_compile": True,
"chat_template": "llama3", "chat_template": "qwen3",
"kd_trainer": True, "kd_trainer": True,
"kd_ce_alpha": 0.1, "kd_ce_alpha": 0.1,
"kd_alpha": 0.9, "kd_alpha": 0.9,

View File

@@ -1,5 +1,7 @@
"""E2E tests for streaming dataset functionality""" """E2E tests for streaming dataset functionality"""
# pylint: disable=duplicate-code
import pytest import pytest
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
@@ -83,84 +85,6 @@ class TestStreamingDatasets:
"Train Loss (%s) is too high", "Train Loss (%s) is too high",
) )
def test_streaming_eval_specific_mixing(self, temp_dir):
"""Test eval-specific mixing strategy override"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 512,
"sample_packing": False,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"test_datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
"split": "train", # Specify train split for eval dataset
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train", # Specify train split for eval dataset
},
],
# Streaming config
"streaming": True,
"eval_streaming": True,
"max_steps": 3,
# Different mixing for train vs eval
"dataset_mixing_strategy": "round_robin",
"eval_dataset_mixing_strategy": "weighted",
"eval_mixing_weights": [0.6, 0.4],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
"eval_steps": 3, # Eval at the end
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# Check both train and eval losses
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
2.5,
"Train Loss (%s) is too high",
)
check_tensorboard(
temp_dir + "/runs",
"eval/eval_loss",
2.5,
"Eval Loss (%s) is too high",
)
def test_streaming_validation_error(self, temp_dir): def test_streaming_validation_error(self, temp_dir):
"""Test that pydantic validation catches invalid streaming configs""" """Test that pydantic validation catches invalid streaming configs"""

View File

@@ -664,42 +664,3 @@ class TestDatasetPreparation:
# Should have samples from both datasets # Should have samples from both datasets
sources = [sample["source"] for sample in samples] sources = [sample["source"] for sample in samples]
assert len(set(sources)) >= 1 # At least one unique source assert len(set(sources)) >= 1 # At least one unique source
def test_eval_streaming_config(self):
"""Test eval_streaming separate from streaming config."""
from axolotl.utils.data.sft import _is_streaming_enabled_for_split
# Test train streaming enabled, eval streaming disabled
cfg = DictDefault({"streaming": True, "eval_streaming": False})
assert _is_streaming_enabled_for_split(cfg, "train")
assert _is_streaming_enabled_for_split(cfg, "test")
# Test train streaming disabled, eval streaming enabled
cfg2 = DictDefault({"streaming": False, "eval_streaming": True})
assert _is_streaming_enabled_for_split(cfg2, "train")
assert _is_streaming_enabled_for_split(cfg2, "test")
def test_eval_specific_mixing_configs(self):
"""Test eval-specific mixing configs override main configs."""
from axolotl.utils.data.sft import _get_streaming_config_for_split
cfg = DictDefault(
{
"dataset_mixing_strategy": "round_robin",
"mixing_weights": [0.5, 0.5],
"eval_dataset_mixing_strategy": "weighted",
"eval_mixing_weights": [0.8, 0.2],
}
)
# Train split should use main config
train_cfg = _get_streaming_config_for_split(cfg, "train")
assert train_cfg["dataset_mixing_strategy"] == "round_robin"
assert train_cfg["mixing_weights"] == [0.5, 0.5]
# Test split should use eval-specific config
test_cfg = _get_streaming_config_for_split(cfg, "test")
assert test_cfg["dataset_mixing_strategy"] == "weighted"
assert test_cfg["mixing_weights"] == [0.8, 0.2]

View File

@@ -1,16 +1,11 @@
"""Module for testing dataset sequence packing""" """Module for testing dataset sequence packing"""
import unittest import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer from transformers import AutoTokenizer
from axolotl.cli.args import TrainerCliArgs from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets from axolotl.common.datasets import load_datasets
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.train import setup_model_and_trainer from axolotl.train import setup_model_and_trainer
from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -36,43 +31,6 @@ class TestPacking(unittest.TestCase):
} }
) )
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
dateset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
@with_temp_dir @with_temp_dir
def test_lora_packing(self, temp_dir): def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code