Compare commits


2 Commits

Author     SHA1        Message                                                              Date
Wing Lian  bb65157dcf  fix conditional for None values                                      2025-08-17 12:49:48 -04:00
Wing Lian  7fd3d8abc4  handle batch size correctly when using split and dispatch batches    2025-08-16 22:05:31 -04:00
34 changed files with 335 additions and 994 deletions

View File

@@ -12,6 +12,5 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
chat:
auto_reply: true

View File

@@ -41,12 +41,6 @@ model, and final model output, you may need at least 3TB of free disk space to k
axolotl train examples/gpt-oss/gpt-oss-120b-fft-fsdp2-offload.yaml
```
To simplify fine-tuning across 2 nodes × 8x H100 (80GB) GPUs, we've partnered with [Baseten](https://baseten.co) to showcase multi-node
training of the 120B model using Baseten Truss. You can read more about this recipe on
[Baseten's blog](https://www.baseten.co/blog/how-to-fine-tune-gpt-oss-120b-with-baseten-and-axolotl/). The recipe can
be found on their
[GitHub](https://github.com/basetenlabs/ml-cookbook/tree/main/examples/oss-gpt-120b-axolotl/training).
ERRATA: Transformers saves the model architecture prefixed with `FSDP`, which needs to be manually renamed in `config.json`.
See https://github.com/huggingface/transformers/pull/40207 for the status of this issue.
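For illustration, a minimal sketch of that rename, assuming the merged checkpoint lives at `./outputs/gpt-oss-out/` (the path and the prefix handling here are assumptions):
```bash
# Sketch only: strip the leading "FSDP" from each entry in the architectures
# list of config.json so inference frameworks can resolve the model class.
python3 - <<'EOF'
import json

path = "./outputs/gpt-oss-out/config.json"  # assumed output location
with open(path) as f:
    config = json.load(f)
config["architectures"] = [a.removeprefix("FSDP") for a in config.get("architectures", [])]
with open(path, "w") as f:
    json.dump(config, f, indent=2)
EOF
```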
@@ -67,23 +61,9 @@ mv ./outputs/gpt-oss-out/merged/* ./outputs/gpt-oss-out/
### Inferencing your fine-tuned model
#### vLLM
GPT-OSS support in vLLM does not exist in a stable release yet. See https://x.com/MaziyarPanahi/status/1955741905515323425
for more information about using a special vllm-openai docker image for inferencing with vLLM.
Optionally, vLLM can be installed from nightly:
```bash
pip install --no-build-isolation --pre -U vllm --extra-index-url https://wheels.vllm.ai/nightly
```
and the vLLM server can be started with the following command (modify `--tensor-parallel-size 8` to match your environment):
```bash
vllm serve ./outputs/gpt-oss-out/ --served-model-name axolotl/gpt-oss-20b --host 0.0.0.0 --port 8888 --tensor-parallel-size 8
```
#### SGLang
SGLang has 0-day support in main; see https://github.com/sgl-project/sglang/issues/8833 for information on installing
SGLang from source. Once you've installed SGLang, run the following command to launch an SGLang server:
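As a rough sketch of what such a launch can look like (the model path, port, and `--tp` value below are assumptions; adjust them to your environment):
```bash
# Illustrative SGLang launch for the fine-tuned checkpoint; --tp should match
# the number of GPUs available on the node.
python3 -m sglang.launch_server \
  --model-path ./outputs/gpt-oss-out/ \
  --host 0.0.0.0 \
  --port 8888 \
  --tp 8
```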

View File

@@ -44,7 +44,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -41,7 +41,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -15,7 +15,7 @@ datasets:
field_thinking: thinking
template_thinking_key: thinking
dataset_prepared_path: ./outputs/last_run_prepared
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/gpt-oss-out/
@@ -40,7 +40,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -53,7 +53,7 @@ bf16: true
tf32: true
flash_attention: true
attn_implementation: kernels-community/vllm-flash-attn3 # this is not needed if using flash_attn >= 2.8.3
attn_implementation: kernels-community/vllm-flash-attn3
gradient_checkpointing: true
activation_offloading: true

View File

@@ -13,8 +13,8 @@ liger-kernel==0.6.1
packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.55.3
peft==0.17.0
transformers==4.55.2
tokenizers>=0.21.1
accelerate==1.10.0
datasets==4.0.0

View File

@@ -118,9 +118,9 @@ def get_package_version():
extras_require = {
"flash-attn": ["flash-attn==2.8.3"],
"flash-attn": ["flash-attn==2.8.2"],
"ring-flash-attn": [
"flash-attn==2.8.3",
"flash-attn==2.8.2",
"ring-flash-attn>=0.1.7",
"yunchang==0.6.0",
],

View File

@@ -14,13 +14,9 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=False,
default=None,
metadata={
"help": (
"[DEPRECATED] No longer supported. For streaming datasets, use "
"'axolotl train' and set 'streaming: true' in your YAML config, or "
"pass --streaming instead in the CLI."
)
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@@ -44,12 +40,6 @@ class VllmServeCliArgs:
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
data_parallel_size: Optional[int] = field(
default=None,
metadata={
"help": "Number of data parallel workers to use for vLLM serving. This controls how many model replicas are used for parallel inference."
},
)
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},

View File

@@ -82,7 +82,7 @@ class ModalCloud(Cloud):
return res
def get_image(self):
docker_tag = "main-py3.11-cu126-2.7.1"
docker_tag = "main-py3.11-cu124-2.6.0"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
@@ -200,7 +200,7 @@ class ModalCloud(Cloud):
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return f"H100:{count}"
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":

View File

@@ -64,7 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer

View File

@@ -35,20 +35,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
)
return
@@ -107,8 +97,7 @@ def do_cli(
"""
# pylint: disable=duplicate-code
os.environ["AXOLOTL_IS_PREPROCESS"] = "1"
is_preprocess = kwargs.pop("is_preprocess", True)
parsed_cfg = load_cfg(config, is_preprocess=is_preprocess, **kwargs)
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(

View File

@@ -3,12 +3,11 @@
import random
from copy import deepcopy
from itertools import product
from typing import Any
def generate_sweep_configs(
base_config: dict[str, list], sweeps_config: dict[str, list]
) -> list[dict[str, Any]]:
) -> list[dict[str, list]]:
"""
Recursively generates all possible configurations by applying sweeps to the base config.

View File

@@ -4,7 +4,6 @@ import os
import subprocess # nosec
import sys
import tempfile
from pathlib import Path
from typing import Any, Iterator, Literal
import yaml
@@ -89,12 +88,7 @@ def generate_config_files(config: str, sweep: str | None) -> Iterator[tuple[str,
# Generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
is_group = len(permutations) > 1
base_output_dir = base_config.get("output_dir", "./model-out")
for idx, permutation in enumerate(permutations, start=1):
permutation_dir = Path(permutation.get("output_dir", base_output_dir))
permutation_id = f"sweep{idx:04d}"
permutation["output_dir"] = str(permutation_dir / permutation_id)
for permutation in permutations:
# pylint: disable=consider-using-with
temp_file = tempfile.NamedTemporaryFile(
mode="w",

View File

@@ -55,11 +55,13 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

View File

@@ -424,7 +424,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
):
if training_args.pretraining:
if (
self.cfg.pretraining_sample_concatenation is False
not self.cfg.pretraining_sample_concatenation
or self.cfg.micro_batch_size > 1
):
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)

View File

@@ -272,6 +272,20 @@ class AxolotlTrainer(
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
if (
self.args.accelerator_config is not None
and self.args.accelerator_config.split_batches
and self.args.accelerator_config.dispatch_batches
):
if self.args.sample_packing and self.args.pretraining:
if not self.args.eval_sample_packing and not is_training:
dataloader_params["batch_size"] *= self.accelerator.num_processes
else:
dataloader_params["batch_size"] = self.accelerator.num_processes
elif not self.args.sample_packing and self.args.pretraining:
dataloader_params["batch_size"] *= self.accelerator.num_processes
if self.args.sample_packing and (
(is_training and not self.args.pretraining)
or (not is_training and self.args.eval_sample_packing is not False)

View File

@@ -1,19 +1,18 @@
"""
Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
datasets.
"""
from typing import Any
"""Module containing Dataset functionality"""
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -43,13 +42,10 @@ class TokenizedPromptDataset(Dataset):
**kwargs,
)
def process(self, dataset: Dataset | IterableDataset) -> Dataset | IterableDataset:
"""Apply filtering and tokenization."""
features = None
if not isinstance(dataset, IterableDataset):
features = dataset.features.keys()
def process(self, dataset):
features = dataset.features.keys()
map_kwargs: dict[str, Any] = {}
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 1_000
@@ -58,28 +54,18 @@ class TokenizedPromptDataset(Dataset):
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
if not isinstance(dataset, IterableDataset):
filter_kwargs["num_proc"] = self.process_count
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
**filter_kwargs,
num_proc=self.process_count,
desc="Strategy Filtering Rows",
)
map_kwargs = {
**map_kwargs,
"desc": "Tokenizing Prompts",
}
# Only add remove_columns for regular datasets
if not isinstance(dataset, IterableDataset):
map_kwargs["remove_columns"] = features
map_kwargs["num_proc"] = self.process_count
map_kwargs["keep_in_memory"] = self.keep_in_memory
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=self.process_count,
remove_columns=features,
keep_in_memory=self.keep_in_memory,
desc="Tokenizing Prompts",
**map_kwargs,
)
@@ -93,16 +79,140 @@ def wrap_dataset_for_tokenized_prompt(
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
# Map the dataset and remove original columns
# For IterableDataset, features might be None until first iteration
remove_columns = None
if dataset.features is not None:
remove_columns = list(dataset.features.keys())
features = list(dataset.features.keys())
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=remove_columns,
remove_columns=features,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__( # pylint: disable=super-init-not-called
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

View File

@@ -187,7 +187,7 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
# Linear4Bit will keep its bias term in fp32. If the weight dtype is in bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(

View File

@@ -253,9 +253,7 @@ def save_trained_model(
# final model weights have already been saved by `ReLoRACallback.on_train_end`
return
if ( # pylint: disable=too-many-nested-blocks
trainer.is_fsdp_enabled or cfg.fsdp_config
):
if trainer.is_fsdp_enabled or cfg.fsdp_config:
if cfg.fsdp_config or cfg.fsdp:
if cfg.fsdp_config.final_state_dict_type:
state_dict_type = cfg.fsdp_config.final_state_dict_type
@@ -287,8 +285,6 @@ def save_trained_model(
if trainer.accelerator.is_main_process:
# move all files in merged_path to cfg.output_dir
for merged_file in Path(merged_path).iterdir():
if (Path(cfg.output_dir) / merged_file.name).exists():
(Path(cfg.output_dir) / merged_file.name).unlink()
shutil.move(str(merged_file), cfg.output_dir)
shutil.rmtree(merged_path) # remove what should be an empty dir
# TODO(wing):see https://github.com/huggingface/transformers/pull/40207

View File

@@ -9,7 +9,6 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
@@ -29,7 +28,7 @@ from axolotl.utils.data.shared import (
)
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
handle_long_seq_in_dataset,
drop_long_seq_in_dataset,
retry_on_request_exceptions,
)
from axolotl.utils.data.wrappers import get_dataset_wrapper
@@ -44,24 +43,12 @@ from axolotl.utils.trainer import (
LOG = get_logger(__name__)
def _is_streaming_enabled(cfg: DictDefault) -> bool:
"""Check if streaming is enabled for a specific split."""
streaming = cfg.get("streaming")
if streaming is True:
return True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = cfg.get("pretraining_dataset") is not None
streaming = has_pretraining and streaming is None
return streaming
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_datasets(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare training and evaluation datasets based on configuration.
@@ -69,19 +56,23 @@ def prepare_datasets(
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: Tokenizer to use for processing text.
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(cfg, tokenizer, processor)
return _prepare_standard_dataset(cfg, tokenizer, processor)
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
)
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
def _prepare_standard_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
@@ -92,6 +83,7 @@ def _prepare_standard_dataset(
cfg,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Overwrite eval_dataset if test data exists
@@ -101,6 +93,7 @@ def _prepare_standard_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
return train_dataset, eval_dataset, prompters
@@ -116,12 +109,7 @@ def _prepare_standard_dataset(
return train_dataset, eval_dataset, -1, prompters
# Validate sample packing configuration for evaluation
if (
eval_dataset
and cfg.sample_packing
and cfg.eval_sample_packing is not False
and not isinstance(eval_dataset, IterableDataset)
):
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
if total_eval_steps == 0:
raise ValueError(
@@ -129,17 +117,13 @@ def _prepare_standard_dataset(
"You should set `eval_sample_packing: False` in your config."
)
# Set total_num_steps for training
if isinstance(train_dataset, IterableDataset):
total_num_steps = cfg.max_steps
# Calculate total number of training steps
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
else:
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
return train_dataset, eval_dataset, total_num_steps, prompters
@@ -148,6 +132,7 @@ def _prepare_pretraining_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""
Prepare dataset for pretraining mode.
@@ -168,6 +153,7 @@ def _prepare_pretraining_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_exact_deduplication:
@@ -270,6 +256,7 @@ def _load_tokenized_prepared_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
"""Load or create tokenized and prepared datasets for training or testing.
@@ -278,51 +265,39 @@ def _load_tokenized_prepared_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (dataset, prompters list).
"""
# Select correct dataset configuration based on split
datasets_configs = cfg.datasets if split == "train" else cfg.test_datasets
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw datasets
prompters: list[Prompter | None] = []
use_streaming = False
if split == "train":
use_streaming = _is_streaming_enabled(cfg)
if use_streaming:
# For streaming datasets, skip caching and load raw datasets directly
if dataset is None:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
preprocess_iterable,
)
else:
# Generate dataset hash for caching
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
# Try loading from hub if push_dataset_to_hub is configured
dataset = None
if cfg.push_dataset_to_hub:
dataset = try_load_from_hub(cfg, dataset_hash, split)
# If not found on hub, try loading from disk
if dataset is None:
dataset = load_preprocessed_dataset(cfg, dataset_hash)
# If not found on disk or skipping prepared dataset, load and process raw
# datasets
if dataset is None:
dataset, prompters = _load_raw_datasets(
cfg,
datasets_configs,
tokenizer,
split,
processor,
)
return dataset, prompters
@@ -331,8 +306,9 @@ def _load_raw_datasets(
cfg: DictDefault,
datasets_configs: list,
tokenizer: PreTrainedTokenizer,
split: Literal["train", "test"],
split: str,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset, list[Prompter | None]]:
"""Load, process, merge, and save raw datasets."""
LOG.info("Loading raw datasets...", main_process_only=False)
@@ -353,6 +329,7 @@ def _load_raw_datasets(
split=split,
seed=cfg.seed,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
@@ -362,18 +339,17 @@ def _load_raw_datasets(
if not cfg.skip_prepare_dataset:
if split == "test" and cfg.eval_sequence_len:
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Only save regular datasets to disk, not streaming datasets
if not isinstance(dataset, IterableDataset):
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset, prompters
@@ -382,22 +358,22 @@ def _load_and_process_single_dataset(
dataset_config: DictDefault,
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
split: Literal["train", "test"],
split: str,
seed: int,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
use_streaming = False
if split == "train":
use_streaming = _is_streaming_enabled(cfg)
# Load the dataset
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, use_streaming
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if isinstance(dataset, DatasetDict):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -442,13 +418,11 @@ def _parse_dataset_type(d_type: str) -> tuple[str | None, str | None]:
def _handle_train_dataset_split(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[Dataset | IterableDataset, Dataset | IterableDataset | None]:
dataset: Dataset, cfg: DictDefault
) -> tuple[Dataset, Dataset | None]:
"""Handle processing for train split, including validation set creation."""
val_set_size = (
int(cfg.val_set_size)
if cfg.val_set_size and cfg.val_set_size > 1
else float(cfg.val_set_size or 0.0)
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
)
if val_set_size:
@@ -459,33 +433,27 @@ def _handle_train_dataset_split(
return train_dataset, eval_dataset
# No validation split - apply deduplication if needed and return as train dataset
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
if cfg.dataset_exact_deduplication:
train_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
train_dataset = dataset
return train_dataset, None
def _handle_test_dataset_split(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> tuple[None, Dataset | IterableDataset | None]:
dataset: Dataset, cfg: DictDefault
) -> tuple[None, Dataset | None]:
"""Handle processing for test split."""
if cfg.dataset_exact_deduplication and not isinstance(dataset, IterableDataset):
if cfg.dataset_exact_deduplication:
eval_dataset, _ = deduplicate_and_log_datasets(dataset=dataset)
else:
if cfg.dataset_exact_deduplication and isinstance(dataset, IterableDataset):
LOG.info("Deduplication skipped for streaming datasets (not compatible)")
eval_dataset = dataset
return None, eval_dataset
def _apply_dataset_sharding(
dataset: Dataset | IterableDataset, cfg: DictDefault
) -> Dataset | IterableDataset:
def _apply_dataset_sharding(dataset: Dataset, cfg: DictDefault) -> Dataset:
"""Apply dataset sharding if configured.
Args:
@@ -511,6 +479,7 @@ def _load_and_prepare_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:
"""Load and prepare datasets with optional validation split and sharding.
@@ -519,6 +488,7 @@ def _load_and_prepare_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, prompters).
@@ -529,6 +499,7 @@ def _load_and_prepare_datasets(
cfg,
split=split,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Apply dataset sharding if configured using shared function

View File

@@ -13,7 +13,6 @@ from datasets import (
IterableDataset,
IterableDatasetDict,
concatenate_datasets,
interleave_datasets,
load_dataset,
load_from_disk,
)
@@ -525,9 +524,7 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -540,23 +537,23 @@ def merge_datasets(
if len(datasets) == 1:
ds = datasets[0]
if (
cfg.curriculum_sampling
or not cfg.shuffle_merged_datasets
or isinstance(ds, IterableDataset)
):
# Do not shuffle if curriculum sampling is enabled or
# shuffle_merged_datasets is disabled
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
return ds
return ds.shuffle(seed=cfg.seed)
if cfg.shuffle_before_merging_datasets and all(
isinstance(ds, Dataset) for ds in datasets
):
# If enabled, shuffle each dataset independently before merging.
# This allows curriculum learning strategies to be applied at the dataset level.
if cfg.shuffle_before_merging_datasets:
LOG.info("Shuffling each dataset individually before merging...")
datasets = [ds.shuffle(seed=cfg.seed) for ds in datasets]
merged_dataset = _merge_datasets_with_strategy(datasets, cfg)
LOG.info("Merging datasets...")
merged_dataset = concatenate_datasets(datasets)
if cfg.shuffle_merged_datasets and not isinstance(merged_dataset, IterableDataset):
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged datasets...")
if cfg.curriculum_sampling:
LOG.warning(
@@ -565,45 +562,6 @@ def merge_datasets(
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
else:
if isinstance(merged_dataset, IterableDataset):
LOG.debug("Skipping shuffle for streaming datasets.")
else:
LOG.debug("Not shuffling merged datasets.")
LOG.debug("Not shuffling merged datasets.")
return merged_dataset
def _merge_datasets_with_strategy(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""
Merge datasets using the configured mixing strategy. Works with streaming and non-
streaming datasets.
Args:
datasets: List of datasets to merge.
cfg: Configuration object containing mixing settings.
Returns:
Merged dataset (Dataset or IterableDataset depending on inputs).
"""
strategy = cfg.get("dataset_mixing_strategy", "concatenate")
weights = cfg.get("mixing_weights", None)
LOG.info(f"Merging datasets with mixing strategy: {strategy}...")
if strategy == "concatenate":
if not all(isinstance(ds, Dataset) for ds in datasets):
raise ValueError(
"Cannot concatenate streaming datasets. Use 'round_robin', 'weighted', "
"or 'random' instead."
)
return concatenate_datasets(datasets)
if strategy == "round_robin":
return interleave_datasets(datasets, seed=cfg.seed)
if strategy == "weighted":
return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
if strategy == "random":
equal_weights = [1.0 / len(datasets)] * len(datasets)
return interleave_datasets(datasets, probabilities=equal_weights, seed=cfg.seed)
raise ValueError(f"Unknown dataset mixing strategy: {strategy}")

View File

@@ -148,36 +148,7 @@ def deduplicate_and_log_datasets(
return dataset, other_dataset
def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
"""
Truncate samples whose sequence length is too long (> sequence_len)
or drop those too short (< min_sequence_len).
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
results = []
# Batched (input_ids is a list of lists)
for i, seq in enumerate(input_ids):
length = len(seq)
if length < min_sequence_len:
results.append(False)
elif length > sequence_len:
sample["input_ids"][i] = seq[:sequence_len]
if "attention_mask" in sample:
sample["attention_mask"][i] = sample["attention_mask"][i][:sequence_len]
if "labels" in sample:
sample["labels"][i] = sample["labels"][i][:sequence_len]
if "position_ids" in sample:
sample["position_ids"][i] = sample["position_ids"][i][:sequence_len]
results.append(True)
else:
results.append(True)
return results
def handle_long_seq_in_dataset(
def drop_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
@@ -190,15 +161,11 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if hasattr(dataset, "column_names") and dataset.column_names:
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This "
"is expected for reward modeling."
)
return dataset
elif isinstance(dataset, IterableDataset):
LOG.info("Skipping drop_long_seq for streaming datasets (not compatible)")
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
drop_long = functools.partial(
@@ -225,21 +192,8 @@ def handle_long_seq_in_dataset(
if filter_map_kwargs:
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
excess_length_strategy = (cfg.excess_length_strategy or "drop").lower()
if excess_length_strategy == "truncate":
process_fn = functools.partial(
truncate_long_seq,
sequence_len=sequence_len,
min_sequence_len=cfg.min_sample_len,
)
drop_long_kwargs["desc"] = (
f"Truncating/Filtering Sequences (target_len={sequence_len})"
)
else:
process_fn = drop_long
dataset = dataset.filter(
process_fn,
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
@@ -247,11 +201,6 @@ def handle_long_seq_in_dataset(
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
action = (
"truncated/filtered"
if excess_length_strategy == "truncate"
else "dropped"
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset

View File

@@ -414,12 +414,6 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
default=None,
json_schema_extra={
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
},
)
eval_sequence_len: int | None = Field(
default=None,
json_schema_extra={
@@ -932,27 +926,9 @@ class AxolotlInputConfig(
fix_untrained_tokens: int | list[int] | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Whether to use streaming datasets (IterableDataset) for training datasets. When True, data is loaded on-demand during training without upfront preprocessing. Requires max_steps to be set. Pre-training datasets default to streaming unless explicitly set to False."
},
)
dataset_mixing_strategy: str | None = Field(
default="round_robin",
json_schema_extra={
"description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
},
)
mixing_weights: list[float] | None = Field(
default=None,
json_schema_extra={
"description": "Weights for weighted mixing strategy when using multiple datasets. Must sum to 1.0 and have same length as datasets list. Only used when dataset_mixing_strategy='weighted'."
},
)
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
preprocess_iterable: bool | None = None
total_num_tokens: int | None = Field(
default=None,

View File

@@ -161,12 +161,7 @@ class HyperparametersConfig(BaseModel):
max_grad_norm: float | None = Field(
default=None, json_schema_extra={"description": "Gradient clipping max norm"}
)
num_epochs: float = Field(
default=1.0,
json_schema_extra={
"description": "Number of iterations over dataset for training"
},
)
num_epochs: float = Field(default=1.0)
@field_validator("batch_size")
@classmethod

View File

@@ -3,7 +3,6 @@
# pylint: disable=too-many-boolean-expressions
import json
import os
import sys
import tempfile
from pathlib import Path
@@ -193,7 +192,6 @@ class AttentionValidationMixin:
return data
# pylint: disable=too-many-public-methods
class TrainingValidationMixin:
"""Validation methods related to training configuration."""
@@ -510,58 +508,11 @@ class TrainingValidationMixin:
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
# due to trying to count the number of tokens total in the dataset
raise ValueError(
"pretraining_dataset and include_tokens_per_second cannot be used "
"together."
"pretraining_dataset and include_tokens_per_second cannot be used together."
)
return data
@model_validator(mode="before")
@classmethod
def check_max_steps_num_epochs_conflict(cls, data):
"""Handle max_steps and num_epochs configuration and auto-set defaults."""
max_steps = data.get("max_steps")
num_epochs = data.get("num_epochs")
# Auto-set num_epochs to 1 if neither max_steps nor num_epochs are set
if max_steps is None and num_epochs is None:
data["num_epochs"] = 1.0
return data
@model_validator(mode="before")
@classmethod
def check_saves_per_epoch_conflicts(cls, data):
"""Ensure saves_per_epoch is compatible with training configuration."""
saves_per_epoch = data.get("saves_per_epoch")
num_epochs = data.get("num_epochs")
if saves_per_epoch is not None:
# Check if saves_per_epoch is set but num_epochs is unset
if num_epochs is None:
raise ValueError(
"saves_per_epoch requires num_epochs to be set to calculate save "
"intervals."
)
return data
@model_validator(mode="before")
@classmethod
def check_evals_per_epoch_conflicts(cls, data):
"""Ensure evals_per_epoch is compatible with training configuration."""
evals_per_epoch = data.get("evals_per_epoch")
num_epochs = data.get("num_epochs")
if evals_per_epoch is not None:
if num_epochs is None:
raise ValueError(
"evals_per_epoch requires num_epochs to be set to calculate "
"evaluation intervals."
)
return data
class LoRAValidationMixin:
"""Validation methods related to LoRA/QLoRA configuration."""
@@ -1127,27 +1078,6 @@ class PretrainingValidationMixin:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_streaming_split_batches_accelerate(cls, data):
# Check if streaming is enabled for training
streaming = data.get("streaming", False)
# If streaming is enabled, configure accelerator
if streaming:
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
class ModelCompatibilityValidationMixin:
"""Validation methods for specific model compatibility."""
@@ -1406,128 +1336,6 @@ class GRPOVllmValidationMixin:
return self
class StreamingValidationMixin:
"""Validation methods related to streaming datasets."""
def _is_streaming_enabled(self) -> bool:
"""Check if streaming is enabled."""
# Fall back to main streaming setting
streaming = getattr(self, "streaming", None)
if streaming is True:
return True
# Check if pretraining dataset exists (defaults to streaming)
has_pretraining = getattr(self, "pretraining_dataset", None) is not None
streaming = has_pretraining and streaming is None
return streaming
@model_validator(mode="after")
def check_streaming_requires_max_steps(self):
"""Ensure max_steps is set when using streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
max_steps = getattr(self, "max_steps", None)
if not max_steps:
raise ValueError("max_steps must be set when using streaming datasets")
return self
@model_validator(mode="after")
def check_streaming_validation_splits_conflict(self):
"""Ensure validation splits are not used with streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
val_set_size = getattr(self, "val_set_size", 0.0)
if val_set_size and val_set_size > 0:
raise ValueError(
"Validation splits not supported for streaming datasets, please "
"use test_datasets: ... instead"
)
return self
@model_validator(mode="after")
def check_streaming_preprocessing_conflict(self):
"""Ensure preprocessing is not enabled with streaming datasets."""
# Check if streaming is enabled for training datasets
if self._is_streaming_enabled():
if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
raise ValueError("preprocess is not supported for streaming datasets")
return self
@model_validator(mode="after")
def check_dataset_mixing_weights(self):
"""Validate dataset mixing weights configuration."""
valid_strategies = ["concatenate", "round_robin", "weighted", "random"]
# Get datasets to validate length against
datasets = getattr(self, "datasets", None)
# Check main strategy and weights
strategy = getattr(self, "dataset_mixing_strategy", "concatenate")
weights = getattr(self, "mixing_weights", None)
dataset_count = len(datasets) if datasets else 0
self._validate_dataset_strategy_and_weights(
strategy,
weights,
"dataset_mixing_strategy",
"mixing_weights",
valid_strategies,
dataset_count,
)
return self
def _validate_dataset_strategy_and_weights(
self,
strategy,
weights,
strategy_field,
weights_field,
valid_strategies,
dataset_count,
):
"""Helper method to validate dataset mixing strategy and weights pair."""
if strategy not in valid_strategies:
raise ValueError(
f"{strategy_field} must be one of {valid_strategies}, "
f"got '{strategy}'"
)
if strategy == "weighted":
if weights is None:
raise ValueError(
f"{weights_field} must be provided when "
f"{strategy_field}='weighted'"
)
if not isinstance(weights, list) or not all(
isinstance(w, (int, float)) for w in weights
):
raise ValueError(f"{weights_field} must be a list of numbers")
if any(w < 0 for w in weights):
raise ValueError(f"{weights_field} must be non-negative")
if abs(sum(weights) - 1.0) > 1e-6:
raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
# Validate weights length against dataset count
if dataset_count > 0 and len(weights) != dataset_count:
raise ValueError(
f"{weights_field} length ({len(weights)}) must match number of datasets ({dataset_count})"
)
elif weights is not None and strategy != "weighted":
LOG.warning(
f"{weights_field} provided but {strategy_field} is '{strategy}'. "
"Weights will be ignored."
)
# pylint: disable=too-many-ancestors
class ValidationMixin(
DatasetValidationMixin,
@@ -1539,7 +1347,6 @@ class ValidationMixin(
SystemValidationMixin,
ChatTemplateValidationMixin,
PretrainingValidationMixin,
StreamingValidationMixin,
ModelCompatibilityValidationMixin,
ComplexValidationMixin,
GRPOVllmValidationMixin,

View File

@@ -10,6 +10,7 @@ from typing import List, Optional
import numpy as np
import torch
import torch.cuda
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -22,65 +23,6 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = get_logger(__name__)
def _create_filtered_iterable_dataset(dataset, filter_fn, batched=False):
"""
Create a filtered IterableDataset that works around a HuggingFace datasets
limitation.
"""
def filtered_generator():
"""Generator that yields only samples that pass the filter function."""
if batched:
batch = []
batch_size = 1000 # Process in batches of 1000
for sample in dataset:
batch.append(sample)
if len(batch) >= batch_size:
# Create a batch dict from list of samples
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
# Apply filter function to batch
keep_mask = filter_fn(batch_dict)
# Yield samples that should be kept
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
batch = []
# Process remaining samples in batch
if batch:
batch_dict = {}
for key in batch[0].keys():
batch_dict[key] = [sample[key] for sample in batch]
keep_mask = filter_fn(batch_dict)
for i, keep in enumerate(keep_mask):
if keep:
yield batch[i]
else:
# For non-batched filtering, apply filter to each sample individually
for sample in dataset:
if filter_fn(sample):
yield sample
# Create new IterableDataset from the filtered generator
filtered_dataset = IterableDataset.from_generator(filtered_generator)
# Preserve the original features if they exist
# pylint:disable=protected-access
if hasattr(dataset, "_info") and dataset._info.features is not None:
filtered_dataset._info.features = dataset._info.features
return filtered_dataset
@torch.jit.script
def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
@@ -340,21 +282,12 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
# For IterableDatasets, always use custom filtering to avoid features issues
if isinstance(train_dataset, IterableDataset):
# IterableDatasets often have None features after transformations,
# so we use our custom filter implementation that doesn't rely on features
train_dataset = _create_filtered_iterable_dataset(
train_dataset, drop_no_trainable_tokens, batched=True
)
else:
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
@@ -539,7 +472,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
)
data_loader = DataLoader(
train_dataset,
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
@@ -614,7 +547,7 @@ def setup_deepspeed_env(cfg, stage=None):
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
# NOTE(djsaunde): The distribued state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.
if (

View File

@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": True,
"chat_template": "qwen3",
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,

View File

@@ -1,185 +0,0 @@
"""E2E tests for streaming dataset functionality"""
# pylint: disable=duplicate-code
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
class TestStreamingDatasets:
"""Test case for streaming datasets with different mixing strategies"""
@pytest.mark.parametrize(
("dataset_mixing_strategy", "mixing_weights"),
[
("round_robin", None),
("weighted", [0.7, 0.3]),
("random", None),
],
)
def test_streaming_dataset_mixing_strategies(
self, temp_dir, dataset_mixing_strategy, mixing_weights
):
"""Test different mixing strategies with streaming datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": False,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3, # Very small for smoke test
"dataset_mixing_strategy": dataset_mixing_strategy,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
# Add mixing weights if specified
if mixing_weights:
cfg["mixing_weights"] = mixing_weights
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
2.5, # Loss should be reasonable for a smoke test (higher threshold for streaming)
"Train Loss (%s) is too high",
)
def test_streaming_validation_error(self, temp_dir):
"""Test that pydantic validation catches invalid streaming configs"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"streaming": True,
"max_steps": 3,
# Invalid: wrong number of weights for datasets
"dataset_mixing_strategy": "weighted",
"mixing_weights": [1.0], # Should be [0.x, 0.y] for 2 datasets
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
)
# This should raise a validation error
with pytest.raises(Exception) as exc_info:
validate_config(cfg)
# Verify it's the right validation error
assert "mixing_weights length" in str(exc_info.value)
assert "must match number of datasets" in str(exc_info.value)
def test_streaming_three_datasets_weighted(self, temp_dir):
"""Test weighted mixing with three datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 512,
"sample_packing": False,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
{
"path": "yahma/alpaca-cleaned",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3,
"dataset_mixing_strategy": "weighted",
"mixing_weights": [0.5, 0.3, 0.2],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
2.5,
"Train Loss (%s) is too high",
)

View File

@@ -7,13 +7,13 @@ from typing import Any, Generator
from unittest.mock import patch
import pytest
from datasets import Dataset, IterableDataset
from datasets import Dataset
from huggingface_hub import snapshot_download
from transformers import PreTrainedTokenizer
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets, prepare_datasets
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (
@@ -24,7 +24,6 @@ from tests.constants import (
from tests.hf_offline_utils import enable_hf_offline
# pylint: disable=too-many-public-methods
class TestDatasetPreparation:
"""Test a configured dataloader."""
@@ -47,24 +46,6 @@ class TestDatasetPreparation:
]
)
@pytest.fixture
def streaming_dataset_fixture(self):
"""Create a streaming dataset fixture for testing."""
def generator():
yield {
"instruction": "Evaluate this sentence for spelling and grammar mistakes",
"input": "He finnished his meal and left the resturant",
"output": "He finished his meal and left the restaurant.",
}
yield {
"instruction": "What is the capital of France?",
"input": "",
"output": "The capital of France is Paris.",
}
return IterableDataset.from_generator(generator)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
@enable_hf_offline
def test_load_hub(self, tokenizer):
@@ -505,162 +486,3 @@ class TestDatasetPreparation:
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
def test_streaming_sft_dataset(self, tokenizer, streaming_dataset_fixture):
"""Test streaming SFT dataset preparation with IterableDataset."""
with patch("axolotl.utils.data.sft.load_dataset_with_config") as mock_load:
mock_load.return_value = streaming_dataset_fixture
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 256,
"streaming": True,
"max_steps": 100, # Required for streaming datasets
"datasets": [
{
"path": "dummy/path",
"type": "alpaca",
},
],
}
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg, tokenizer
)
# Verify it returns an IterableDataset
assert isinstance(train_dataset, IterableDataset)
assert eval_dataset is None # No eval split for streaming
assert total_num_steps == 100 # Should use max_steps
assert len(prompters) == 1
# Test that we can iterate through the dataset
sample_count = 0
for sample in train_dataset:
assert "input_ids" in sample
assert "attention_mask" in sample
assert "labels" in sample
sample_count += 1
if sample_count >= 2: # Just test first few samples
break
assert sample_count == 2
def test_dataset_mixing_strategy_validation(self):
"""Test validation of dataset mixing strategy configuration."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Test valid strategies work
valid_strategies = ["round_robin", "weighted", "random"]
dataset1 = Dataset.from_dict({"text": ["a"], "source": ["ds1"]})
dataset2 = Dataset.from_dict({"text": ["b"], "source": ["ds2"]})
for strategy in valid_strategies:
cfg = DictDefault(
{
"dataset_mixing_strategy": strategy,
"mixing_weights": [0.5, 0.5] if strategy == "weighted" else None,
"seed": 42,
}
)
# Should not raise an error
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
assert len(merged) >= 1
def test_regular_dataset_round_robin_mixing(self):
"""Test round-robin mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
{"text": ["ds1_item1", "ds1_item2"], "source": ["ds1", "ds1"]}
)
dataset2 = Dataset.from_dict(
{"text": ["ds2_item1", "ds2_item2"], "source": ["ds2", "ds2"]}
)
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have all samples from both datasets
assert len(merged) == 4
assert isinstance(merged, Dataset)
# Check that samples are interleaved (not just concatenated)
sources = [sample["source"] for sample in merged]
# Round-robin should alternate between datasets
assert sources != ["ds1", "ds1", "ds2", "ds2"] # Not concatenated
def test_regular_dataset_weighted_mixing(self):
"""Test weighted mixing for regular datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test datasets
dataset1 = Dataset.from_dict(
{
"text": ["ds1_item1", "ds1_item2", "ds1_item3", "ds1_item4"],
"source": ["ds1"] * 4,
}
)
dataset2 = Dataset.from_dict(
{
"text": ["ds2_item1", "ds2_item2", "ds2_item3", "ds2_item4"],
"source": ["ds2"] * 4,
}
)
cfg = DictDefault(
{
"dataset_mixing_strategy": "weighted",
"mixing_weights": [0.75, 0.25], # 3:1 ratio
"seed": 42,
}
)
merged = _merge_datasets_with_strategy([dataset1, dataset2], cfg)
# Should have samples proportional to weights
assert len(merged) > 0
assert isinstance(merged, Dataset)
# Count samples from each dataset
sources = [sample["source"] for sample in merged]
ds1_count = sources.count("ds1")
ds2_count = sources.count("ds2")
# Should have samples from both datasets
assert ds1_count > 0 and ds2_count > 0 # Both datasets should be represented
def test_streaming_dataset_mixing(self):
"""Test that streaming datasets use HuggingFace interleave_datasets."""
from axolotl.utils.data.shared import _merge_datasets_with_strategy
# Create test streaming datasets
def gen1():
yield {"text": "stream1_item1", "source": "stream1"}
yield {"text": "stream1_item2", "source": "stream1"}
def gen2():
yield {"text": "stream2_item1", "source": "stream2"}
yield {"text": "stream2_item2", "source": "stream2"}
stream1 = IterableDataset.from_generator(gen1)
stream2 = IterableDataset.from_generator(gen2)
cfg = DictDefault({"dataset_mixing_strategy": "round_robin", "seed": 42})
merged = _merge_datasets_with_strategy([stream1, stream2], cfg)
# Should return an IterableDataset
assert isinstance(merged, IterableDataset)
# Test that we can iterate and get samples
samples = list(merged.take(3))
assert len(samples) >= 2 # Should get at least 2 samples
# Should have samples from both datasets
sources = [sample["source"] for sample in samples]
assert len(set(sources)) >= 1 # At least one unique source

View File

@@ -8,7 +8,7 @@ from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import handle_long_seq_in_dataset
from axolotl.utils.data.utils import drop_long_seq_in_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
)
train_dataset = concatenate_datasets([dataset_wrapper])
train_dataset = handle_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
lengths = get_dataset_lengths(train_dataset)
batch_sampler = MultipackBatchSampler(

View File

@@ -1,11 +1,16 @@
"""Module for testing dataset sequence packing"""
import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.train import setup_model_and_trainer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@@ -31,6 +36,43 @@ class TestPacking(unittest.TestCase):
}
)
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
dateset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, dateset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
@with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code