Compare commits


3 Commits

Author        SHA1        Message                Date
Dan Saunders  d3bea3a2eb  broken                 2025-08-25 16:51:36 +00:00
Dan Saunders  2e2302aae3  remove unused          2025-08-25 15:46:25 +00:00
Dan Saunders  3a35076513  seems to be working?   2025-08-25 14:22:32 +00:00
76 changed files with 1093 additions and 1974 deletions

View File

@@ -12,6 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: false
auto_incremental_review: true
chat:
auto_reply: true

View File

@@ -36,11 +36,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -115,11 +110,6 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -179,12 +169,6 @@ jobs:
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -36,15 +36,15 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras: vllm
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
pytorch: 2.7.1
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
timeout-minutes: 20
steps:
@@ -240,7 +240,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -298,13 +298,6 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
gpu_type: "B200"
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -325,7 +318,6 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "GPU_TYPE=${{ matrix.gpu_type || 'L40S'}}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
echo "E2E_DOCKERFILE=${{ matrix.dockerfile || 'Dockerfile.jinja'}}" >> $GITHUB_ENV
- name: Run tests job on Modal
@@ -342,10 +334,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 126
cuda_version: 12.6.3
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.7.1
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -11,7 +11,7 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.12
rev: v0.12.9
hooks:
- id: ruff
args: [--fix]

View File

@@ -17,7 +17,6 @@
<br/>
<a href="https://discord.com/invite/HhrNrHJPRb"><img src="https://img.shields.io/badge/discord-7289da.svg?style=flat-square&logo=discord" alt="discord" style="height: 20px;"></a>
<a href="https://twitter.com/axolotl_ai"><img src="https://img.shields.io/twitter/follow/axolotl_ai?style=social" alt="twitter" style="height: 20px;"></a>
<a href="https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google-colab" style="height: 20px;"></a>
<br/>
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/tests-nightly.yml/badge.svg" alt="tests-nightly">
<img src="https://github.com/axolotl-ai-cloud/axolotl/actions/workflows/multi-gpu-e2e.yml/badge.svg" alt="multigpu-semi-weekly tests">
@@ -71,10 +70,6 @@ Features:
- Python 3.11
- PyTorch ≥2.6.0
### Google Colab
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
### Installation
#### Using pip

View File

@@ -153,7 +153,7 @@ quartodoc:
- utils.distributed
- utils.dict
- utils.optimizers.adopt
- utils.data.streaming
- utils.data.pretraining
- utils.data.sft
- utils.quantization
- title: Schemas
@@ -272,7 +272,6 @@ website:
contents:
- docs/batch_vs_grad.qmd
- docs/dataset_preprocessing.qmd
- docs/streaming.qmd
- docs/multipack.qmd
- docs/mixed_precision.qmd
- docs/optimizers.qmd

View File

@@ -57,8 +57,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_TYPE = os.environ.get("GPU_TYPE", "L40S")
GPU_CONFIG = f"{GPU_TYPE}:{N_GPUS}"
GPU_CONFIG = f"L40S:{N_GPUS}"
def run_cmd(cmd: str, run_folder: str):

View File

@@ -12,7 +12,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null
@@ -27,7 +27,7 @@ coverage:
default:
# basic
target: auto
threshold: 1%
threshold: 0%
base: auto
# advanced
branches: null

View File

@@ -134,7 +134,7 @@ For providers supporting Docker:
### Google Colab {#sec-colab}
[![](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/axolotl-ai-cloud/axolotl/blob/main/examples/colab-notebooks/colab-axolotl-example.ipynb#scrollTo=msOCO4NRmRLa)
Use our [example notebook](../examples/colab-notebooks/colab-axolotl-example.ipynb).
## Platform-Specific Instructions {#sec-platform-specific}

View File

@@ -63,6 +63,15 @@ Start from Stage 1 -> Stage 2 -> Stage 3.
:::
::: {.callout-tip}
Using ZeRO Stage 3 with Single-GPU training
ZeRO Stage 3 can be used for training on a single GPU by manually setting the environment variables:
`WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500`
:::
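For concreteness, a minimal single-GPU ZeRO Stage 3 sketch; the `deepspeed_configs/zero3.json` path mirrors the other DeepSpeed examples in this repo and is an assumption here:
```yaml
# Hypothetical single-GPU ZeRO-3 setup; the config path is an assumption
deepspeed: deepspeed_configs/zero3.json
micro_batch_size: 1
gradient_accumulation_steps: 4
```
Launch it with the environment variables from the tip above, e.g. `WORLD_SIZE=1 LOCAL_RANK=0 MASTER_ADDR=0.0.0.0 MASTER_PORT=29500 axolotl train config.yaml`.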
## Fully Sharded Data Parallel (FSDP) {#sec-fsdp}
::: {.callout-note}

View File

@@ -11,7 +11,6 @@ We support the reward modelling techniques supported by `trl`.
### (Outcome) Reward Models
Outcome reward models are trained on data containing preference annotations for an entire interaction between the user and the model, rather than per-turn or per-step annotations.
For improved training stability, you can use the `center_rewards_coefficient` parameter to encourage mean-zero reward outputs ([see TRL docs](https://huggingface.co/docs/trl/v0.10.1/en/reward_trainer#centering-rewards)).
```yaml
base_model: google/gemma-2-2b

View File

@@ -1,120 +0,0 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---
Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.
Use streaming when:
- Your dataset is too large to fit in memory (e.g. when you're doing pretraining with massive text corpora)
- You want to start training immediately without preprocessing the entire dataset
Streaming works with both remote and locally stored datasets!
::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::
## Configuration
### Basic Streaming
Enable streaming mode by setting the `streaming` flag:
```yaml
streaming: true
```
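Streaming also works for local files; a minimal sketch, assuming a hypothetical local JSONL file (`ds_type` selects the on-disk format):
```yaml
streaming: true
datasets:
  - path: ./data/corpus.jsonl  # hypothetical local file
    ds_type: json
    type: completion
```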
### Pretraining with Streaming
For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:
```yaml
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
type: pretrain
text_column: text
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
### SFT with Streaming
For supervised fine-tuning with streaming:
```yaml
streaming: true
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```
## Configuration Options
### `streaming_multipack_buffer_size`
Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
### `shuffle_merged_datasets`
When enabled, shuffles the streaming dataset using the buffer. This requires additional
memory for the shuffle buffer.
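Putting both options together, a hedged sketch (the buffer size is illustrative, not a tuned recommendation):
```yaml
streaming: true
streaming_multipack_buffer_size: 20000  # larger buffer: better packing, more memory
shuffle_merged_datasets: true  # shuffles via a buffer, at the cost of extra memory
```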
## Sample Packing with Streaming
Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:
```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000
# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```
For more information, see our [documentation](multipack.qmd) on multipacking.
## Important Considerations
### Memory Usage
While streaming reduces memory usage compared to loading entire datasets, you still need
to consider:
- The multipack buffer: you can control its memory footprint by adjusting `streaming_multipack_buffer_size`
- Sample packing requires buffering multiple samples
- Shuffling requires additional memory for the shuffle buffer
### Performance
- Streaming may have slightly higher latency compared to preprocessed datasets, as samples are processed on-the-fly
- Network speed matters when streaming from remote sources; disk read speed matters for local datasets
- Consider using `axolotl preprocess` for smaller or more frequently used datasets
### Evaluation Datasets
Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.
## Examples
See the `examples/streaming/` directory for complete configuration examples:
- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming

View File

@@ -1,10 +0,0 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
]
},
{

View File

@@ -1,68 +0,0 @@
base_model: google/gemma-3-270m-it
# optionally might have model_type or tokenizer_type
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
# gemma3 doesn't seem to play nice with ddp
ddp_find_unused_parameters: true
load_in_8bit: false
load_in_4bit: true
# huggingface repo
chat_template: gemma3
eot_tokens:
- <end_of_turn>
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
field_messages: conversations
message_property_mappings:
role: from
content: value
val_set_size: 0.0
output_dir: ./outputs/out
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
sequence_len: 2048
sample_packing: true
eval_sample_packing: false
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.1
evals_per_epoch:
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -1,44 +0,0 @@
base_model: Skywork/Skywork-Reward-V2-Qwen3-8B
model_type: AutoModelForSequenceClassification
num_labels: 1
reward_model: true
center_rewards_coefficient: 0.01 # Incentivize mean-zero rewards for improved stability
chat_template: qwen3
datasets:
- path: argilla/distilabel-intel-orca-dpo-pairs
type: bradley_terry.chat_template
val_set_size: 0.0
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: false
eval_sample_packing: false
pad_to_sequence_len: true
deepspeed: deepspeed_configs/zero1.json
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: linear
learning_rate: 0.00002
bf16: true
tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
warmup_ratio: 0.1
logging_steps: 1
weight_decay: 0.01

View File

@@ -1,50 +0,0 @@
# Streaming Dataset Examples
This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.
## Examples
Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` required!
### Pretraining (`pretrain.yaml`)
Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.
- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management
### SFT (`sft.yaml`)
Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.
- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed
## Key Configuration Options
### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`
### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory
### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer
### `sample_packing`
- Packs multiple samples into single sequences
- Minimizes per-step padding tokens
## Performance Tips
- Download small / frequently-used datasets locally for better performance
- Larger buffer sizes improve packing efficiency

View File

@@ -1,57 +0,0 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Streaming pretraining configuration
pretraining_dataset:
- path: HuggingFaceFW/fineweb-edu
name: sample-10BT
type: pretrain
text_column: text
split: train
# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -1,55 +0,0 @@
base_model: HuggingFaceTB/SmolLM2-135M
# Dataset configuration
datasets:
- path: tatsu-lab/alpaca
type: alpaca
split: train
# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true
# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming
# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true
# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1
# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0
# Precision and performance
bf16: auto
tf32: true
# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3
# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
# Special tokens
special_tokens:
pad_token: "<|endoftext|>"
# save_first_step: true # uncomment this to validate checkpoint saving works with your config

View File

@@ -2,7 +2,8 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
triton>=3.0.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
@@ -13,7 +14,7 @@ packaging==23.2
huggingface_hub>=0.33.0
peft>=0.17.0
transformers==4.56.1
transformers==4.55.3
tokenizers>=0.21.1
accelerate==1.10.0
datasets==4.0.0

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
)

View File

@@ -64,9 +64,7 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
@@ -127,7 +125,7 @@ extras_require = {
"yunchang==0.6.0",
],
"deepspeed": [
"deepspeed==0.17.5",
"deepspeed==0.17.2",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -14,13 +14,9 @@ class PreprocessCliArgs:
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=False,
default=None,
metadata={
"help": (
"Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
"datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
"config, or pass --streaming instead in the CLI."
)
"help": "Use IterableDataset for streaming processing of large datasets"
},
)

View File

@@ -7,8 +7,6 @@ from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
@@ -40,15 +38,8 @@ def do_cli_train(
cwd=None,
**kwargs,
) -> None:
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}

View File

@@ -1,48 +0,0 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -1,9 +0,0 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -1,71 +0,0 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.x for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -14,7 +14,10 @@ from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
@@ -61,9 +64,7 @@ def do_inference(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
@@ -158,13 +159,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template_from_config(
cfg, ds_cfg=None, tokenizer=tokenizer
)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -43,10 +43,7 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -35,20 +35,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
check_accelerate_default_config()
check_user_token()
if cli_args.iterable:
LOG.error(
"The --iterable CLI argument for 'axolotl preprocess' is no longer "
"supported. For training, set 'streaming: true' in your YAML config or "
"pass '--streaming' in your 'axolotl train' command for on-the-fly "
"preprocessing."
)
return
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
if cfg.get(key):
LOG.error(
f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
"train' CLI directly instead."
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
)
return

View File

@@ -84,6 +84,5 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -55,11 +55,13 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = getattr(cli_args, "iterable", False)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

View File

@@ -24,7 +24,9 @@ from pathlib import Path
from typing import Any
import torch
from transformers import TrainerCallback
from transformers import (
TrainerCallback,
)
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -510,7 +512,6 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -7,7 +7,10 @@ from pathlib import Path
from typing import Type, Union
import transformers
from transformers import DataCollatorWithFlattening, EarlyStoppingCallback
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
)
from trl.trainer.utils import RewardDataCollatorWithPadding
from axolotl.core.builders.base import TrainerBuilderBase
@@ -23,12 +26,12 @@ from axolotl.monkeypatch.relora import ReLoRACallback
from axolotl.processing_strategies import get_processing_strategy
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
LossWatchDogCallback,
SaveBetterTransformerModelCallback,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.qat import QATCallback
@@ -39,7 +42,6 @@ from axolotl.utils.collators import (
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
@@ -72,12 +74,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.qat:
callbacks.append(QATCallback(self.cfg.qat))
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -344,10 +340,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
if self.cfg.center_rewards_coefficient is not None:
training_arguments_kwargs["center_rewards_coefficient"] = (
self.cfg.center_rewards_coefficient
)
elif self.cfg.process_reward_model:
training_args_cls = AxolotlPRMConfig
else:
@@ -412,9 +404,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)

View File

@@ -42,7 +42,6 @@ from axolotl.core.trainers.utils import (
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -64,15 +63,6 @@ class AxolotlTrainer(
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
_axolotl_cfg: DictDefault | None = None
@property
def axolotl_cfg(self):
return self._axolotl_cfg
@axolotl_cfg.setter
def axolotl_cfg(self, cfg):
self._axolotl_cfg = cfg
def __init__(
self,
@@ -88,6 +78,7 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
@@ -336,17 +327,6 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -546,6 +526,9 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -593,19 +576,12 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
@@ -681,11 +657,6 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
self.data_collator.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -3,14 +3,11 @@ Trainer mixin for activation checkpointing w offloading
"""
import contextlib
from functools import partial
from peft import PeftModel
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
checkpoint_wrapper,
CheckpointImpl,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from transformers import GradientCheckpointingLayer, Trainer
@@ -49,20 +46,9 @@ class ActivationOffloadingMixin(Trainer):
return super().training_step(*args, **kwargs)
def ac_wrap_hf_model(model: nn.Module, use_reentrant=None, **kwargs):
def ac_wrap_hf_model(model: nn.Module, **kwargs):
auto_wrap_policy = ModuleWrapPolicy(set((GradientCheckpointingLayer,)))
if use_reentrant:
checkpoint_wrapper_fn = partial(
checkpoint_wrapper, checkpoint_impl=CheckpointImpl.REENTRANT
)
else:
checkpoint_wrapper_fn = checkpoint_wrapper
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=checkpoint_wrapper_fn,
auto_wrap_policy=auto_wrap_policy,
**kwargs,
)
apply_activation_checkpointing(model, auto_wrap_policy=auto_wrap_policy, **kwargs)
def get_lora_act_offloading_ctx_manager(

View File

@@ -49,12 +49,6 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -1,17 +1,18 @@
"""
Module containing dataset functionality.
We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
datasets.
"""
"""Module containing Dataset functionality"""
import torch
from datasets import Dataset, IterableDataset
from axolotl.utils.logging import get_logger
from .prompt_tokenizers import PromptTokenizingStrategy
# We want this to be a wrapper for an existing dataset that we have loaded
# lets use the concept of middlewares to wrap each dataset, for example
# ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
# let's check to ensure we don't truncate an item in the middle, we'll use
# the collators later on to pad the datasets
LOG = get_logger(__name__)
@@ -85,3 +86,133 @@ def wrap_dataset_for_tokenized_prompt(
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""Iterable dataset that returns constant length chunks of tokens from stream of
text files.
Args:
tokenizer: The processor used for processing the data.
dataset: Dataset with text files.
seq_length: Length of token sequences to return.
"""
def __init__(
self,
tokenizer,
datasets,
seq_length=2048,
):
self.tokenizer = tokenizer
self.concat_token_id = tokenizer.eos_token_id
self.datasets: list[IterableDataset] = datasets
self.seq_length = seq_length
vocab_size = len(tokenizer.get_vocab())
if vocab_size <= torch.iinfo(torch.int16).max:
self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
self.tokens_dtype = torch.int32
else:
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
add_concat_token = False
if example:
example_len = len(example["input_ids"])
add_concat_token = example["input_ids"][-1] != self.concat_token_id
else:
example_len = 0
if not example_len or (
buffer_len + int(add_concat_token) + example_len > self.seq_length
):
if buffer["input_ids"]:
input_ids = torch.cat(buffer["input_ids"], dim=-1)[
: self.seq_length
]
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
):
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
"Dropping batch due to tensor size mismatch "
f"input_ids: {input_ids.size()}, "
f"labels: {labels.size()}, "
f"attention_mask: {attention_mask.size()}"
)
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
# just going to drop data points that are too long
if len(example["input_ids"]) <= self.seq_length:
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if add_concat_token:
input_ids.append(self.concat_token_id)
attention_mask.append(1)
labels.append(self.concat_token_id)
input_ids_with_concat = torch.tensor(
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
)

View File

@@ -98,8 +98,6 @@ def load_lora(
lora_config_kwargs["use_rslora"] = cfg.peft_use_rslora
if cfg.peft_layer_replication:
lora_config_kwargs["layer_replication"] = cfg.peft_layer_replication
if cfg.peft_trainable_token_indices:
lora_config_kwargs["trainable_token_indices"] = cfg.peft_trainable_token_indices
lora_config = LoraConfig(
r=cfg.lora_r,

View File

@@ -224,27 +224,21 @@ class ModelLoader:
):
self.model = self.model.merge_and_unload()
use_reentrant = None
if (
self.cfg.gradient_checkpointing_kwargs
and self.cfg.gradient_checkpointing_kwargs.get("use_reentrant", True)
):
use_reentrant = True
self._apply_activation_checkpointing(use_reentrant=use_reentrant)
self._apply_activation_checkpointing()
self._resize_token_embeddings()
self._adjust_model_config()
self._configure_embedding_dtypes()
self._configure_qat()
log_gpu_memory_usage(LOG, "Memory usage after model load", 0)
def _apply_activation_checkpointing(self, use_reentrant: bool | None = None):
def _apply_activation_checkpointing(self):
if self.cfg.activation_offloading is True:
from axolotl.core.trainers.mixins.activation_checkpointing import (
ac_wrap_hf_model,
)
# ^^ importing this at the module level breaks plugins
ac_wrap_hf_model(self.model, use_reentrant=use_reentrant)
ac_wrap_hf_model(self.model)
def _resize_token_embeddings(self):
"""Resize token embeddings if needed."""

View File

@@ -3,7 +3,6 @@
Applies pre- and post-model load patches for various fixes and optimizations.
"""
import os
import importlib.util
from functools import cached_property
@@ -67,7 +66,6 @@ class PatchManager:
self._apply_mistral_cross_entropy_patch()
self._apply_self_attention_lora_patch()
self._apply_fsdp2_bnb_patches()
self._apply_patch_deepspeed_zero3()
def apply_post_plugin_pre_model_load_patches(self):
"""Apply post plugin-pre_model_load load patches based on config."""
@@ -80,7 +78,13 @@ class PatchManager:
patch_maybe_log_save_evaluate,
)
patch_evaluation_loop()
patch_fsdp2 = (
self.cfg.torch_compile
and self.cfg.fsdp_config
and self.cfg.fsdp_version == 2
)
patch_evaluation_loop(patch_fsdp2)
patch_maybe_log_save_evaluate()
def apply_post_model_load_patches(self, model: PreTrainedModel):
@@ -143,12 +147,14 @@ class PatchManager:
def _apply_flex_attention_patches(self):
"""Apply patches for flexible attention."""
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_wrapper,
)
flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
patch_flex_wrapper(**flex_attn_compile_kwargs)
# from axolotl.monkeypatch.attention.flex_attn import (
# patch_flex_make_mask,
# patch_flex_wrapper,
# )
#
# flex_attn_compile_kwargs = self.cfg.flex_attn_compile_kwargs or {}
# patch_flex_wrapper(**flex_attn_compile_kwargs)
# patch_flex_make_mask()
if self.cfg.sample_packing:
from axolotl.core.attention.flex_block_mask import (
patch_create_causal_mask,
@@ -465,16 +471,3 @@ class PatchManager:
from axolotl.monkeypatch.lora_kernels import apply_lora_kernel_patches
apply_lora_kernel_patches(model=model, cfg=self.cfg)
def _apply_patch_deepspeed_zero3(self):
try:
from axolotl.monkeypatch.deepspeed_utils import apply_deepspeed_patches
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
if self.cfg.activation_offloading is True and (
is_deepspeed_zero3_enabled()
or os.getenv("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3"
):
apply_deepspeed_patches()
except ImportError as e:
LOG.warning(f"DeepSpeed patches not applied: {e}")

View File

@@ -1,11 +1,11 @@
"""Flex attention monkey patch"""
import sys
from packaging import version
from typing import Optional, Tuple, Union
import torch
import transformers
from transformers.utils.import_utils import _torch_version, is_torch_less_or_equal
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -46,33 +46,19 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
"""
self.training = None
if not self._is_flex_compiled or training != self.training:
self.training = training
if is_torch_less_or_equal("2.5.1"):
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False
)
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training
elif version.parse(_torch_version).base_version == "2.6.0" and training:
self._compiled_flex_attention = torch.compile(
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
)
# Fallback, usually the most recent torch 2.7.x+ versions
else:
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
main_process_only=True,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info(
"Flex attention compiled successfully.", main_process_only=True
)
self.training = training
LOG.info(
"Compiling flex attention with kwargs: %s. This may take a while...",
flex_attn_compile_kwargs,
)
self._compiled_flex_attention = torch.compile(
flex_attention,
**flex_attn_compile_kwargs,
)
LOG.info("Flex attention compiled successfully.")
self._is_flex_compiled = True
def __call__(self):
@@ -82,3 +68,139 @@ def patch_flex_wrapper(**flex_attn_compile_kwargs):
sys.modules[
"transformers.integrations.flex_attention"
].WrappedFlexAttention = WrappedFlexAttention
def patch_flex_make_mask():
is_torch_2_6 = torch.__version__.startswith("2.6")
if not is_torch_2_6:
return
from torch.nn.attention.flex_attention import (
_DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size,
)
from torch.nn.attention.flex_attention import (
BlockMask,
)
from torch.nn.attention.flex_attention import (
create_block_mask as create_block_causal_mask_flex,
)
Offset = Union[torch.Tensor, int]
def patched_make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
) -> "BlockMask":
"""
Create a block causal document mask for a batch of sequences, both packed and unpacked.
Creates the block-causal logic and passes it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
The resultant BlockMask is a compressed representation of the full block causal
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
Args:
attention_mask_2d (torch.Tensor): Attention mask for packed and padded sequences
of shape (batch_size, total_seq_len). e.g.
For unpacked sequence:
[[1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0]]
For packed sequence:
[[1, 1, 1, 2, 2, 2, 0],
[1, 1, 2, 2, 2, 3, 3]]
Returns:
BlockMask
"""
batch_size, total_seq_len = attention_mask_2d.shape
if not key_length:
key_length = total_seq_len
if not query_length:
query_length = total_seq_len
attention_mask_2d = torch.nn.functional.pad(
attention_mask_2d,
value=0,
pad=(0, abs(total_seq_len - max(key_length, flex_default_block_size))),
)
device = attention_mask_2d.device
document_ids = attention_mask_2d.clone()
if attention_chunk_size is not None:
# we create an arange, then we just // by chunk size to get [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
chunk_idxs = (document_ids.clone().fill_(1).cumsum(-1) - 1) // (
attention_chunk_size
)
# Instead of passing a tensor mask, flex attention requires a mask_mod function
# that determines which elements of QK^T should be included in the attention
# computation prior to the softmax. For sample packing, we need both the
# logic for both causal mask and document mask. See PyTorch's official
# blog post for more details: https://pytorch.org/blog/flexattention/#mask-mods
def causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Defines the logic of a block causal mask by combining both a standard causal mask
and a block diagonal document mask.
See :func:`~torchtune.modules.attention_utils.create_block_causal_mask`
for an illustration.
"""
causal_mask = q_idx >= kv_idx # not valid when decoding
document_mask = (
document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
)
padding_mask = attention_mask_2d[batch_idx, q_idx] > 0
final_mask = causal_mask & padding_mask & document_mask
return final_mask
def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
"""
Combines the chunk mask with the causal mask for chunked attention.
"""
chunk_mask = chunk_idxs[batch_idx, q_idx] == chunk_idxs[batch_idx, kv_idx]
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
return chunk_mask & causal_doc_mask
mask_mod_maybe_combined = (
causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
)
if offsets is not None:
q_offset = offsets[0]
kv_offset = offsets[1]
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
offset_q = q_idx + q_offset
offset_kv = kv_idx + kv_offset
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = mask_mod_maybe_combined
return create_block_causal_mask_flex(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
Q_LEN=query_length,
KV_LEN=key_length,
device=device,
_compile=True,
)
for n in tuple(sys.modules):
if ".modeling_" in n:
if hasattr(sys.modules[n], "make_flex_block_causal_mask"):
sys.modules[
n
].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
transformers.integrations.flex_attention.make_flex_block_causal_mask = (
patched_make_flex_block_causal_mask
)

View File

@@ -1,66 +0,0 @@
import importlib
import importlib.util
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
def patch_checkpoint_wrapper_setattr():
"""
Patch CheckpointWrapper to properly forward DeepSpeed attributes to wrapped modules.
This fixes the issue where CheckpointWrapper doesn't forward ds_* attributes
(like ds_grads_remaining) to the actual wrapped module, causing DeepSpeed
ZeRO-3 to fail when gradient checkpointing is enabled.
This issue occurs specifically with:
- QLoRA + DeepSpeed ZeRO-3
- gradient_checkpointing: true
- activation_offloading: true
References:
- https://github.com/deepspeedai/DeepSpeed/issues/7203
- https://github.com/deepspeedai/DeepSpeed/blob/38d1a9eb64c9e01e32eccc50b25ba18925287441/deepspeed/runtime/zero/parameter_offload.py#L424-L458
- https://github.com/axolotl-ai-cloud/axolotl/pull/3102
"""
try:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
# Check if already patched
if hasattr(CheckpointWrapper, "_axolotl_setattr_patched"):
LOG.debug("CheckpointWrapper already patched")
return
original_setattr = CheckpointWrapper.__setattr__
def new_setattr(self, name: str, value) -> None:
if name.startswith("ds_") and hasattr(self, "_checkpoint_wrapped_module"):
setattr(self._checkpoint_wrapped_module, name, value)
LOG.debug(
f"Forwarded {name} to wrapped module {type(self._checkpoint_wrapped_module).__name__}"
)
else:
original_setattr(self, name, value)
CheckpointWrapper.__setattr__ = new_setattr
CheckpointWrapper._axolotl_setattr_patched = True
LOG.info("CheckpointWrapper patched to forward DeepSpeed attributes")
except ImportError as e:
LOG.debug(f"CheckpointWrapper not available: {e}")
except Exception as e:
LOG.warning(f"Failed to patch CheckpointWrapper: {e}")
def apply_deepspeed_patches():
"""
Apply DeepSpeed-related patches
"""
if importlib.util.find_spec("deepspeed") is not None:
patch_checkpoint_wrapper_setattr()
else:
LOG.debug("DeepSpeed not available, skipping patches")

View File

@@ -149,11 +149,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return MistralAttention
if model_type == "gemma3_text":
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
return Gemma3Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"

View File

@@ -8,94 +8,6 @@ from typing import List
import torch
class DeepSpeedTiledMLPMoE(torch.autograd.Function):
@staticmethod
def forward(
ctx,
fn,
self,
x,
shards,
compute_params,
) -> torch.Tensor:
ctx.fn = fn
ctx.self = self
ctx.shards = shards
ctx.compute_params = [p for p in compute_params if p.requires_grad]
ctx.save_for_backward(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@staticmethod
def backward(ctx, *grads) -> torch.Tensor:
fn = ctx.fn
(x,) = ctx.saved_tensors
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
is_tuple_output = ctx.is_tuple_output
x_requires_grad = x.requires_grad
x = x.detach()
# detach() unsets `x.requires_grad`, so restore it
x.requires_grad_(x_requires_grad)
incoming_grad = grads[0]
x_grad = torch.zeros_like(x)
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
shard_step = x_shards[0].numel()
for i, x_shard in enumerate(x_shards):
# Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
if compute_params is not None:
if i + 1 < shards:
for param in compute_params:
param.ds_grad_is_ready = False
else:
# last shard, can add the grad
for param in compute_params:
param.ds_grad_is_ready = True
x_shard.requires_grad_(x_requires_grad)
shard_offset = i * shard_step
x_shard.grad = (
x_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
incoming_grad_shard = (
incoming_grad.view(-1)
.narrow(0, shard_offset, x_shard.numel())
.view_as(x_shard)
)
with torch.enable_grad():
output = fn(self, x_shard)
if is_tuple_output:
torch.autograd.backward(output[0], incoming_grad_shard)
else:
torch.autograd.backward(output, incoming_grad_shard)
return (None, None, x_grad, None, None)
class TiledMLP(torch.autograd.Function):
"""
TiledMLP implementation using gradient hooks
@@ -119,18 +31,7 @@ class TiledMLP(torch.autograd.Function):
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)
output_unsharded = torch.cat(output_shards, dim=1)
return output_unsharded
@@ -141,7 +42,6 @@ class TiledMLP(torch.autograd.Function):
self = ctx.self
shards = ctx.shards
compute_params = ctx.compute_params
is_tuple_output = ctx.is_tuple_output
x_requires_grad = x.requires_grad
x = x.detach()
@@ -176,10 +76,7 @@ class TiledMLP(torch.autograd.Function):
with torch.enable_grad():
output = fn(self, x_shard)
if is_tuple_output:
torch.autograd.backward(output[0], incoming_grad_shard)
else:
torch.autograd.backward(output, incoming_grad_shard)
torch.autograd.backward(output, incoming_grad_shard)
# Clean up hooks
grad_accumulator.cleanup()

View File

@@ -17,7 +17,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
TiledMLP as DeepSpeedTiledMLP,
)
from axolotl.monkeypatch.tiled_mlp.base import DeepSpeedTiledMLPMoE, TiledMLP
from axolotl.monkeypatch.tiled_mlp.base import TiledMLP
try:
# Dynamically import the module and MLP class
@@ -64,10 +64,7 @@ def patch_tiled_mlp(model_type, use_original_mlp=True, cfg_num_shards=None):
for p in self._compute_params
)
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
if model_type == "gpt_oss":
self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
else:
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
else:
self._tiled_mlp_dist_impl = TiledMLP

View File

@@ -28,6 +28,15 @@ PATCHED_EVAL_CODE = {
"array": 'metrics[f"{metric_key_prefix}_loss"] = np.nanmean(all_losses).item()',
}
ORIGINAL_FSDP2_CODE = """
model.eval()
"""
PATCHED_FSDP2_CODE = """
if hasattr(model, "eval") and callable(model.eval):
self.model.eval()
"""
ORIGINAL_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).mean().item()"
PATCHED_MAYBE_CODE = "tr_loss_scalar = self._nested_gather(tr_loss).nanmean().item()"
@@ -37,7 +46,13 @@ def check_evaluation_loop_is_patchable() -> bool:
return all(value in evaluation_loop_source for value in ORIGINAL_EVAL_CODE.values())
def patch_evaluation_loop():
def check_evaluation_loop_is_fsdp2_patchable() -> bool:
evaluation_loop_source = inspect.getsource(Trainer.evaluation_loop)
evaluation_loop_source, _ = detab_code(evaluation_loop_source)
return ORIGINAL_FSDP2_CODE in evaluation_loop_source
def patch_evaluation_loop(patch_fsdp2: bool):
"""Patch the evaluation_loop method."""
# Check if already patched
if hasattr(Trainer, "_original_evaluation_loop"):
@@ -60,6 +75,13 @@ def patch_evaluation_loop():
ORIGINAL_EVAL_CODE["array"], PATCHED_EVAL_CODE["array"]
)
# Apply FSDP2 eval guard patch if needed
if patch_fsdp2 and ORIGINAL_FSDP2_CODE in evaluation_loop_source:
evaluation_loop_source = evaluation_loop_source.replace(
ORIGINAL_FSDP2_CODE, PATCHED_FSDP2_CODE
)
LOG.info("Applied FSDP2 eval guard patch to evaluation_loop")
# Rename the function to avoid conflicts
evaluation_loop_source = evaluation_loop_source.replace(
"def evaluation_loop(",

View File

@@ -75,7 +75,7 @@ class PromptTokenizingStrategy(abc.ABC):
) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
if not prompt:
LOG.warning_once("Empty text requested for tokenization.")
LOG.warning("Empty text requested for tokenization.")
return empty
result = self.tokenizer(

View File

@@ -416,9 +416,7 @@ def save_initial_configs(
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
tokenizer.save_pretrained(str(output_dir))
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
@@ -594,9 +592,6 @@ def train(
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()

View File

@@ -60,14 +60,13 @@ def gpu_memory_usage_all(device=0):
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
torch.cuda.reset_peak_memory_stats(device)
return active, allocated, reserved
def mps_memory_usage_all():
active = torch.mps.current_allocated_memory() / 1024.0**3
allocated = torch.mps.driver_allocated_memory() / 1024.0**3
return active, allocated, 0
usage = torch.mps.current_allocated_memory() / 1024.0**3
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
return usage, reserved - usage, 0
def npu_memory_usage_all(device=0):

View File

@@ -1,64 +0,0 @@
"""A callback for calculating tokens per second during training."""
import time
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
class TokensPerSecondCallback(TrainerCallback):
"""
A callback to measure and log tokens per second during training.
"""
def __init__(self, tensor_parallel_size, context_parallel_size):
super().__init__()
self.step_time = 0.0
self.start_time = 0.0
self.non_data_parallel_size = 1
if tensor_parallel_size is not None:
self.non_data_parallel_size *= tensor_parallel_size
if context_parallel_size is not None:
self.non_data_parallel_size *= context_parallel_size
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
self.start_time = time.perf_counter()
state.last_tokens_per_second = torch.zeros(1)
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
if hasattr(state, "num_tokens"):
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = state.num_tokens.clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs=None,
**kwargs,
): # pylint: disable=unused-argument
# after logging, clear the running metrics
if hasattr(state, "last_tokens_per_second"):
state.last_tokens_per_second.zero_()
state.num_tokens = torch.zeros(1)

View File

@@ -1,17 +1,11 @@
"""Shared axolotl collators for multipacking, mamba, multimodal."""
"""
shared axolotl collators for multipack, mamba, multimodal
"""
from .batching import (
from .batching import ( # noqa: F401
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
PretrainingBatchSamplerDataCollatorForSeq2Seq,
V2BatchSamplerDataCollatorForSeq2Seq,
)
from .mamba import MambaDataCollator
__all__ = [
"DataCollatorForSeq2Seq",
"BatchSamplerDataCollatorForSeq2Seq",
"V2BatchSamplerDataCollatorForSeq2Seq",
"PretrainingBatchSamplerDataCollatorForSeq2Seq",
"MambaDataCollator",
]
from .mamba import MambaDataCollator # noqa: F401

View File

@@ -77,7 +77,7 @@ def resolve_dtype(cfg):
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16 and cfg.fp16 is not False:
if cfg.bf16:
cfg.fp16 = True
cfg.bf16 = False
else:
@@ -273,9 +273,7 @@ def validate_config(
# Convert datasets to proper format if needed
if cfg.get("datasets"):
for idx, ds_cfg in enumerate(cfg["datasets"]):
if cfg.get("rl") in ["dpo", "ipo", "simpo"] and not isinstance(
ds_cfg, DPODataset
):
if cfg.get("rl") in ["dpo", "simpo"] and not isinstance(ds_cfg, DPODataset):
cfg["datasets"][idx] = DPODataset(**ds_cfg)
elif cfg.get("rl") == "kto" and not isinstance(ds_cfg, KTODataset):
cfg["datasets"][idx] = KTODataset(**dict(ds_cfg))

View File

@@ -48,10 +48,10 @@ def apply_sequence_parallelism(
- The original sequence length before padding.
- The number of padding tokens added.
"""
batch_size, original_seq_len = batch["input_ids"].shape
original_seq_len = batch["input_ids"].size(1)
# Update ring attention params if needed
if batch.get("position_ids") is not None and batch_size == 1:
if batch.get("position_ids") is not None:
update_ring_attn_params(position_ids=batch["position_ids"])
else:
# If position_ids aren't already in the batch, create them

View File

@@ -1,8 +1,8 @@
"""Init for `axolotl.utils.data` module."""
from axolotl.utils.data.streaming import (
encode_streaming,
wrap_streaming_dataset,
from axolotl.utils.data.pretraining import (
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
@@ -12,8 +12,8 @@ from axolotl.utils.data.sft import (
from axolotl.utils.data.utils import md5
__all__ = [
"encode_streaming",
"wrap_streaming_dataset",
"encode_pretraining",
"wrap_pretraining_dataset",
"prepare_preference_datasets",
"get_dataset_wrapper",
"prepare_datasets",

View File

@@ -0,0 +1,292 @@
"""data handling specific to pretraining"""
import functools
from collections import defaultdict
from typing import Callable, Dict, List, Optional
import torch
from datasets import Dataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
LOG = get_logger(__name__)
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]:
res = tokenizer(
examples[text_column],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = []
new_labels = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
input_ids[i] = torch.cat(
(
input_ids[i],
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
),
dim=0,
)
targets[i] = torch.cat(
(
targets[i],
torch.tensor([tokenizer.eos_token_id, -100]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, labels, mask in zip(input_ids, targets, attention_mask, strict=False):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_labels],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
LOG.debug(len(ret["input_ids"]))
return ret
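A hypothetical usage sketch of the function above (the tokenizer choice and token counts are illustrative): every returned row is packed or padded to exactly max_tokens.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with an EOS token
tok.pad_token = tok.eos_token  # gpt2 has no pad token by default

out = encode_pretraining(tok, 16, {"text": ["hello world", "foo bar", "baz"]})
assert all(len(row) == 16 for row in out["input_ids"])
# labels mirror input_ids, except pad positions are set to -100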
def wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens,
multipack_attn=cfg.pretrain_multipack_attn,
)
encode = functools.partial(
encode_packed_pretraining,
collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens,
batch_size=batch_size,
multipack_attn=cfg.pretrain_multipack_attn,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
concatenate=cfg.pretraining_sample_concatenation is True,
)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
else:
LOG.debug("NOT shuffling merged pretraining datasets")
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
# this is empty during streaming/pretraining
remove_columns = []
if dataset.features is None:
for first_row in dataset:
remove_columns = list(first_row.keys())
break
else:
remove_columns = list(dataset.features.keys())
dataset = dataset.map(
encode,
batched=True,
batch_size=buffer_size,
# input_columns="text",
remove_columns=remove_columns,
)
return dataset
def encode_packed_pretraining(
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
# tokenize all the examples
# rows get split with stride (overlap)
train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(
train_dataset,
max_seq_length,
skip_position_ids=not multipack_attn,
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
)
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
lengths=get_dataset_lengths(train_dataset),
batch_size=1,
batch_max_len=batch_size * max_seq_length,
drop_last=True,
num_processes=1,
)
chunked_data = defaultdict(list)
for batch in sampler:
for data in batch:
features = train_dataset[data]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data
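The sampler above packs variable-length rows into fixed token budgets. A toy first-fit-decreasing illustration of the general idea (the real MultipackBatchSampler uses a more careful algorithm):

def first_fit_pack(lengths: list[int], max_len: int) -> list[list[int]]:
    """Greedy first-fit-decreasing bin packing over sample indices."""
    bins: list[dict] = []
    for idx, n in sorted(enumerate(lengths), key=lambda t: -t[1]):
        for b in bins:
            if b["used"] + n <= max_len:
                b["idx"].append(idx)
                b["used"] += n
                break
        else:
            bins.append({"idx": [idx], "used": n})
    return [b["idx"] for b in bins]

print(first_fit_pack([900, 600, 500, 120], 1024))  # [[0, 3], [1], [2]]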

View File

@@ -9,14 +9,13 @@ from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import PreTrainedTokenizer, ProcessorMixin
from axolotl.prompters import Prompter
from axolotl.utils.data.lock import FileLockLoader
from axolotl.utils.data.streaming import wrap_streaming_dataset
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import (
create_train_validation_split,
datasets_with_name_generator,
@@ -27,6 +26,7 @@ from axolotl.utils.data.shared import (
save_preprocessed_dataset,
try_load_from_hub,
)
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
handle_long_seq_in_dataset,
@@ -49,6 +49,7 @@ def prepare_datasets(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None = None,
preprocess_iterable: bool = False,
) -> tuple[IterableDataset | Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare training and evaluation datasets based on configuration.
@@ -56,20 +57,24 @@ def prepare_datasets(
cfg: Dictionary mapping `axolotl` config keys to values.
tokenizer: Tokenizer to use for processing text.
processor: Optional processor for multimodal datasets.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, total_steps, prompters).
"""
if cfg.streaming or cfg.pretraining_dataset:
return _prepare_streaming_dataset(cfg, tokenizer, processor)
return _prepare_standard_dataset(cfg, tokenizer, processor)
if cfg.pretraining_dataset:
return _prepare_pretraining_dataset(
cfg, tokenizer, processor, preprocess_iterable
)
return _prepare_standard_dataset(cfg, tokenizer, processor, preprocess_iterable)
def _prepare_standard_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
preprocess_iterable: bool,
) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
def _load_datasets():
@@ -79,6 +84,7 @@ def _prepare_standard_dataset(
cfg,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
# Overwrite eval_dataset if test data exists
@@ -88,6 +94,7 @@ def _prepare_standard_dataset(
cfg,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
return train_dataset, eval_dataset, prompters
@@ -112,7 +119,14 @@ def _prepare_standard_dataset(
)
# Calculate total number of training steps
if cfg.max_steps:
# For streaming datasets, we must use max_steps
if isinstance(train_dataset, IterableDataset):
if not cfg.max_steps:
raise ValueError(
"When using streaming datasets, you must set max_steps in your config"
)
total_num_steps = cfg.max_steps
elif cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
@@ -122,40 +136,22 @@ def _prepare_standard_dataset(
return train_dataset, eval_dataset, total_num_steps, prompters
def _prepare_streaming_dataset(
def _prepare_pretraining_dataset(
cfg: DictDefault,
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[IterableDataset, Dataset | None, int, list[Prompter | None]]:
"""
Prepare dataset for streaming mode.
Prepare dataset for pretraining mode.
Note: Streaming datasets are loaded incrementally from the source.
Note: Pre-training datasets are streamed from the HuggingFace Hub.
"""
if cfg.pretraining_dataset:
dataset_config = _extract_pretraining_config(cfg)
train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer)
elif cfg.sample_packing:
# TODO(djsaunde): Implement for multiple datasets
dataset_config = DictDefault(cfg.datasets[0])
# Extract pretraining dataset configuration
pretraining_config = _extract_pretraining_config(cfg)
# Ensure we have a split set - default to 'train' if not specified
if not hasattr(dataset_config, "split") or not dataset_config.split:
dataset_config.split = "train"
train_dataset = _load_streaming_dataset(dataset_config, cfg, tokenizer)
else:
# Use legacy loading function for non-packed streaming datasets
train_dataset, eval_dataset, prompters = _load_and_prepare_datasets(
tokenizer,
cfg,
split="train",
processor=processor,
streaming=True,
)
# Return early for non-packed streaming datasets
total_num_steps = cfg.max_steps if cfg.max_steps else -1
return train_dataset, eval_dataset, total_num_steps, prompters
# Load streaming dataset for training
train_dataset = _load_pretraining_dataset(pretraining_config, cfg, tokenizer)
# Load evaluation dataset if specified
eval_dataset = None
@@ -165,12 +161,14 @@ def _prepare_streaming_dataset(
cfg,
split="test",
processor=processor,
streaming=False,
preprocess_iterable=preprocess_iterable,
)
# For streaming, we return max_steps directly from config or -1 if not set
total_num_steps = cfg.max_steps if cfg.max_steps else -1
return train_dataset, eval_dataset, total_num_steps, []
if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets")
# For pretraining, we return max_steps directly from config
return train_dataset, eval_dataset, cfg.max_steps, []
def _extract_pretraining_config(cfg: DictDefault) -> DictDefault:
@@ -202,7 +200,7 @@ def _extract_pretraining_config(cfg: DictDefault) -> DictDefault:
)
def _load_streaming_dataset(
def _load_pretraining_dataset(
pretraining_config: DictDefault, cfg: DictDefault, tokenizer: PreTrainedTokenizer
) -> IterableDataset:
"""Load and prepare a streaming dataset for pretraining."""
@@ -237,11 +235,15 @@ def _load_streaming_dataset(
iter_dataset = iter_dataset.skip(pretraining_config["skip"])
# Wrap the dataset for pretraining
train_dataset = wrap_streaming_dataset(
train_dataset = wrap_pretraining_dataset(
iter_dataset,
tokenizer,
cfg,
dataset_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed,
buffer_size=cfg.pretrain_multipack_buffer_size or 10_000,
)
# Format for PyTorch
@@ -262,7 +264,7 @@ def _load_tokenized_prepared_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
streaming: bool = False,
preprocess_iterable: bool = False,
) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
"""Load or create tokenized and prepared datasets for training or testing.
@@ -271,7 +273,7 @@ def _load_tokenized_prepared_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
streaming: Whether to use iterable preprocessing.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (dataset, prompters list).
@@ -302,7 +304,7 @@ def _load_tokenized_prepared_datasets(
tokenizer,
split,
processor,
streaming,
preprocess_iterable,
)
return dataset, prompters
@@ -314,7 +316,7 @@ def _load_raw_datasets(
tokenizer: PreTrainedTokenizer,
split: str,
processor: ProcessorMixin | None = None,
streaming: bool = False,
preprocess_iterable: bool = False,
) -> tuple[Dataset, list[Prompter | None]]:
"""Load, process, merge, and save raw datasets."""
LOG.info("Loading raw datasets...", main_process_only=False)
@@ -335,7 +337,7 @@ def _load_raw_datasets(
split=split,
seed=cfg.seed,
processor=processor,
streaming=streaming,
preprocess_iterable=preprocess_iterable,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
@@ -343,19 +345,23 @@ def _load_raw_datasets(
# Merge datasets
dataset = merge_datasets(datasets, cfg)
if not cfg.skip_prepare_dataset and not streaming:
if not cfg.skip_prepare_dataset:
if split == "test" and cfg.eval_sequence_len:
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
if cfg.sample_packing:
# Skip packing processing for streaming datasets - they handle it differently
if cfg.sample_packing and not isinstance(dataset, IterableDataset):
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
# Skip saving for streaming datasets as they can't be cached
if not isinstance(dataset, IterableDataset):
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset, prompters
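The save/skip logic above keys the prepared-dataset cache on a hash of the dataset configs plus tokenizer. A minimal sketch of that kind of cache key (illustrative, not the exact generate_dataset_hash_from_config):

import hashlib
import json

def dataset_cache_key(datasets_configs, tokenizer_name: str) -> str:
    blob = json.dumps(
        {"datasets": datasets_configs, "tokenizer": tokenizer_name},
        sort_keys=True,
        default=str,
    )
    return hashlib.md5(blob.encode("utf-8")).hexdigest()  # nosec: cache key only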
@@ -367,19 +373,21 @@ def _load_and_process_single_dataset(
split: str,
seed: int,
processor: ProcessorMixin | None = None,
streaming: bool = False,
preprocess_iterable: bool = False,
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
# Use streaming if enabled in config or if using iterable preprocessing
use_streaming = cfg.streaming or preprocess_iterable
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=streaming
dataset_config, cfg.hf_use_auth_token, streaming=use_streaming
)
# Parse dataset type
d_base_type, d_prompt_style = _parse_dataset_type(dataset_config.type)
# Select the appropriate split
if isinstance(dataset, (DatasetDict, IterableDatasetDict)):
if isinstance(dataset, DatasetDict):
if dataset_config.split and dataset_config.split in dataset:
dataset = dataset[dataset_config.split]
elif split in dataset:
@@ -397,16 +405,63 @@ def _load_and_process_single_dataset(
num_shards=dataset_config.shards, index=shards_idx
)
# Apply dataset wrapper
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# For streaming datasets, we need to handle tokenization differently
if isinstance(dataset, IterableDataset):
# Use pretraining's approach for multipack streaming
if cfg.sample_packing:
# Create the dataset wrapper function once
def ds_wrapper_fn(dataset=None):
wrapped_dataset, prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
return wrapped_dataset, prompter
# Use pretraining wrapper for efficient streaming SFT with packing
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
dataset_wrapper = wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed,
buffer_size=cfg.pretrain_multipack_buffer_size,
)
else:
# Use regular streaming wrapper
dataset_wrapper = wrap_streaming_sft_dataset(
dataset,
tokenizer,
cfg,
dataset_config,
d_base_type,
d_prompt_style,
processor,
max_tokens=cfg.sequence_len,
buffer_size=10_000,
)
# For streaming, we don't have a specific prompter
dataset_prompter = None
else:
# Apply dataset wrapper for regular datasets
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
return dataset_wrapper, dataset_prompter
@@ -485,7 +540,7 @@ def _load_and_prepare_datasets(
cfg: DictDefault,
split: Literal["train", "test"] = "train",
processor: ProcessorMixin | None = None,
streaming: bool = False,
preprocess_iterable: bool = False,
) -> tuple[Dataset | None, Dataset | None, list[Prompter | None]]:
"""Load and prepare datasets with optional validation split and sharding.
@@ -494,7 +549,7 @@ def _load_and_prepare_datasets(
cfg: Configuration object.
split: Dataset split to load ('train' or 'test').
processor: Optional processor for multimodal datasets.
streaming: Whether to use iterable preprocessing.
preprocess_iterable: Whether to use iterable preprocessing.
Returns:
Tuple of (train_dataset, eval_dataset, prompters).
@@ -505,7 +560,7 @@ def _load_and_prepare_datasets(
cfg,
split=split,
processor=processor,
streaming=streaming,
preprocess_iterable=preprocess_iterable,
)
# Apply dataset sharding if configured using shared function

View File

@@ -236,9 +236,11 @@ def _load_from_local_path(
try:
return load_from_disk(dataset_config.path)
except FileNotFoundError:
load_dataset_kwargs["streaming"] = False
return load_dataset(dataset_config.path, **load_dataset_kwargs)
elif local_path.is_file():
dataset_type = get_dataset_type(dataset_config)
load_dataset_kwargs["streaming"] = False
return load_dataset(
dataset_type,
data_files=dataset_config.path,
@@ -522,7 +524,9 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -532,6 +536,41 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
Returns:
Merged dataset.
"""
# Check if we're dealing with streaming datasets
if any(isinstance(ds, IterableDataset) for ds in datasets):
# All datasets must be streaming for merging
if not all(isinstance(ds, IterableDataset) for ds in datasets):
raise ValueError(
"Cannot mix streaming and non-streaming datasets. "
"Either all datasets must be streaming or none."
)
if len(datasets) == 1:
ds = datasets[0]
# Streaming datasets handle shuffling differently
if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
return ds
# Merge streaming datasets
LOG.info("Merging streaming datasets...")
from datasets import interleave_datasets
# For streaming, we interleave datasets instead of concatenating
merged_dataset = interleave_datasets(datasets)
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged streaming datasets...")
if cfg.curriculum_sampling:
LOG.warning(
"Shuffling merged datasets with curriculum sampling is not recommended. "
"This will randomize the order of samples."
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
return merged_dataset
# Original logic for non-streaming datasets
if len(datasets) == 1:
ds = datasets[0]
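A small demonstration of the interleaving behavior used for streaming merges above, as opposed to plain concatenation (the dataset contents are made up):

from datasets import Dataset, interleave_datasets

a = Dataset.from_dict({"x": [1, 2]}).to_iterable_dataset()
b = Dataset.from_dict({"x": [10, 20]}).to_iterable_dataset()

merged = interleave_datasets([a, b])  # alternates examples across sources
print([row["x"] for row in merged])  # [1, 10, 2, 20]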

View File

@@ -1,301 +1,150 @@
"""Data handling specific to streaming datasets."""
"""Utilities for handling streaming datasets."""
import functools
from collections import defaultdict
from typing import Callable, Dict, List, Optional
from typing import Any, Dict, List
import torch
from datasets import Dataset
import numpy as np
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
from axolotl.utils.trainer import add_position_ids
LOG = get_logger(__name__)
def encode_streaming(
examples: Dict[str, List],
def wrap_streaming_sft_dataset(
dataset: IterableDataset,
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]:
res = tokenizer(
examples[text_column],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = []
new_labels = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
input_ids[i] = torch.cat(
(
input_ids[i],
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
),
dim=0,
)
targets[i] = torch.cat(
(
targets[i],
torch.tensor([tokenizer.eos_token_id, -100]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, labels, mask in zip(input_ids, targets, attention_mask, strict=False):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
buffer_input_ids = torch.cat(
(
buffer_input_ids,
torch.full(
(max_tokens - buffer_input_ids.numel(),),
tokenizer.pad_token_id,
dtype=torch.long,
),
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
torch.full(
(max_tokens - buffer_attention_mask.numel(),),
0,
dtype=torch.long,
),
),
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_labels],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
LOG.debug(len(ret["input_ids"]))
return ret
def wrap_streaming_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
):
if cfg.sample_packing:
# For SFT (non-pretraining) datasets, always use multipack_attn=True to ensure
# attention isolation between packed sequences
multipack_attn = (
True if not cfg.pretraining_dataset else cfg.pretrain_multipack_attn
)
dataset_config,
d_base_type: str,
d_prompt_style: str | None,
processor: Any | None,
max_tokens: int = 2048,
buffer_size: int = 10_000,
) -> IterableDataset:
"""
Wrap a streaming SFT dataset with tokenization and optional packing.
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=cfg.sequence_len,
multipack_attn=multipack_attn,
)
encode = functools.partial(
encode_packed_streaming,
collate_fn,
ds_wrapper_fn,
max_seq_length=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
multipack_attn=multipack_attn,
)
This is similar to wrap_pretraining_dataset but for SFT datasets.
# Set this to 1 so downstream data_loader doesn't try to increase the batch size
# again
cfg.micro_batch_size = 1
else:
# NOTE: This is not reachable for SFT datasets since we use the pre-existing
# loading function for non-packed streaming datasets. Refer to
# _prepare_streaming_dataset in sft.py for that code path.
text_column = (
getattr(cfg.pretraining_dataset[0], "text_column", "text") or "text"
)
encode = functools.partial(
encode_streaming,
tokenizer=tokenizer,
max_tokens=cfg.sequence_len,
text_column=text_column,
concatenate=cfg.pretraining_sample_concatenation is True,
)
Args:
dataset: The streaming dataset to wrap
tokenizer: Tokenizer to use
cfg: Configuration object
dataset_config: Dataset configuration
d_base_type: Base dataset type
d_prompt_style: Prompt style
processor: Optional processor for multimodal
max_tokens: Maximum sequence length
buffer_size: Buffer size for shuffling
Returns:
Wrapped streaming dataset ready for training
"""
# Import here to avoid circular imports
from axolotl.utils.data.wrappers import get_dataset_wrapper
# Apply shuffling if configured
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(
seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size
)
else:
LOG.debug("NOT shuffling merged pretraining datasets")
LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
# this is empty during streaming/pretraining
# For streaming datasets, we need to get column names from the first sample
remove_columns = []
if dataset.features is None:
for first_row in dataset:
remove_columns = list(first_row.keys())
break
else:
remove_columns = list(dataset.features.keys())
for first_row in dataset:
remove_columns = list(first_row.keys())
break
# Reset dataset after peeking
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
# Define the encoding function - always add position_ids for compatibility
if cfg.sample_packing:
# For sample packing, we need to handle position_ids
def encode_streaming_packed(examples: Dict[str, List]) -> Dict[str, List]:
"""Encode examples for streaming with sample packing."""
# Convert the batch dict to a temporary Dataset for processing
temp_dataset = Dataset.from_dict(examples)
# Apply the dataset wrapper to tokenize
wrapped_dataset, _ = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=temp_dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# Convert to dict for processing
result = {}
if hasattr(wrapped_dataset, "to_dict"):
result = wrapped_dataset.to_dict()
else:
for key in wrapped_dataset.column_names:
result[key] = wrapped_dataset[key]
# Add position_ids using the existing function
result = add_position_ids(result)
# For multipack attention, we may need to drop attention_mask
if cfg.pretrain_multipack_attn and "attention_mask" in result:
del result["attention_mask"]
return result
encode_fn = encode_streaming_packed
else:
# Regular encoding without packing - still add position_ids for compatibility
def encode_streaming(examples: Dict[str, List]) -> Dict[str, List]:
"""Encode examples for streaming."""
# Convert the batch dict to a temporary Dataset for processing
temp_dataset = Dataset.from_dict(examples)
# Apply the dataset wrapper to tokenize
wrapped_dataset, _ = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=temp_dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# Convert to dict format
result = {}
if hasattr(wrapped_dataset, "to_dict"):
result = wrapped_dataset.to_dict()
else:
for key in wrapped_dataset.column_names:
result[key] = wrapped_dataset[key]
# Add position_ids even without packing for compatibility
result = add_position_ids(result)
return result
encode_fn = encode_streaming
# Map the encoding function over the streaming dataset
dataset = dataset.map(
encode,
encode_fn,
batched=True,
batch_size=cfg.streaming_multipack_buffer_size,
batch_size=buffer_size,
remove_columns=remove_columns,
)
# Set format for PyTorch
dataset = dataset.with_format("torch")
return dataset
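For reference, a self-contained sketch of the batched map-with-column-removal pattern used above on a streaming dataset (toy data):

from datasets import Dataset

stream = Dataset.from_dict({"text": ["a b c", "d e"]}).to_iterable_dataset()
stream = stream.map(
    lambda batch: {"n_words": [len(t.split()) for t in batch["text"]]},
    batched=True,
    batch_size=2,
    remove_columns=["text"],  # original columns differ in length post-encode
)
print(list(stream))  # [{'n_words': 3}, {'n_words': 2}]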
def encode_packed_streaming(
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
max_seq_length: int = 2048,
batch_size: int = 4,
multipack_attn: Optional[bool] = True,
) -> Dict[str, List]:
# tokenize all the examples
# rows get split with stride (overlap)
train_dataset = ds_wrapper(dataset=Dataset.from_dict(examples))[0]
train_dataset = process_pretraining_datasets_for_packing(
train_dataset,
max_seq_length,
skip_position_ids=not multipack_attn,
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
)
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
lengths=get_dataset_lengths(train_dataset),
batch_size=1,
batch_max_len=batch_size * max_seq_length,
drop_last=True,
num_processes=1,
)
chunked_data = defaultdict(list)
for batch in sampler:
for data in batch:
features = train_dataset[data]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data

View File

@@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def handle_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault
) -> Dataset | IterableDataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
@@ -190,21 +190,19 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
if (
hasattr(dataset, "column_names")
and dataset.column_names
and "input_ids" not in dataset.column_names
):
# Streaming datasets don't support filtering the same way
if isinstance(dataset, IterableDataset):
LOG.info(
"Streaming dataset detected - long sequence filtering will be done on-the-fly"
)
return dataset
if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."
)
return dataset
elif not hasattr(dataset, "column_names") or dataset.column_names is None:
LOG.info(
"Dataset is streaming (IterableDataset), skipping long sequence handling"
)
return dataset
drop_long = functools.partial(
drop_long_seq,

View File

@@ -138,12 +138,6 @@ class AxolotlInputConfig(
"description": "Process reward modelling: `True` or `False`"
},
)
center_rewards_coefficient: float | None = Field(
default=None,
json_schema_extra={
"description": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
},
)
num_labels: int | None = None
# Whether to use weighting in DPO trainer.
# If `None`, default is `False` in the trainer.
@@ -250,6 +244,12 @@ class AxolotlInputConfig(
dataloader_num_workers: int | None = None
dataloader_prefetch_factor: int | None = None
dataloader_drop_last: bool | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable streaming mode for training datasets to reduce memory usage and enable training on datasets larger than memory"
},
)
accelerator_config: dict[str, Any] | None = None
@@ -481,6 +481,12 @@ class AxolotlInputConfig(
},
)
multipack_real_batches: bool | None = None
pretraining_sample_concatenation: bool | None = Field(
default=None,
json_schema_extra={
"description": "whether to concatenate samples during pretraining",
},
)
batch_flattening: Literal["auto"] | bool | None = Field(
default=None,
@@ -495,34 +501,13 @@ class AxolotlInputConfig(
pose_max_context_len: int | None = None
pose_num_chunks: int | None = None
# Deprecated: Use streaming_multipack_buffer_size instead
pretrain_multipack_buffer_size: int | None = Field(
default=None,
deprecated="Deprecated in v0.13.0, will be removed in v0.14.0. Use streaming_multipack_buffer_size instead",
)
pretrain_multipack_buffer_size: int | None = 10_000
pretrain_multipack_attn: bool | None = Field(
default=True,
json_schema_extra={
"description": "whether to prevent cross attention for packed sequences during pretraining",
},
)
pretraining_sample_concatenation: bool | None = Field(
default=None,
json_schema_extra={
"description": "whether to concatenate samples during pretraining",
},
)
streaming: bool | None = Field(
default=None,
json_schema_extra={"description": "Use streaming mode for loading datasets"},
)
streaming_multipack_buffer_size: int | None = Field(
default=10_000,
json_schema_extra={
"description": "Buffer size for multipack streaming datasets"
},
)
xformers_attention: bool | None = Field(
default=None,
@@ -851,15 +836,10 @@ class AxolotlInputConfig(
include_tokens_per_second: bool | None = Field(
default=None,
json_schema_extra={
"description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
},
)
include_tkps: bool | None = Field(
default=True,
json_schema_extra={
"description": "bool of whether to report tokens per second per-gpu during training by measuring throughput of non-padding tokens."
"description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
},
)
neftune_noise_alpha: float | None = Field(
default=None,
json_schema_extra={
@@ -953,15 +933,7 @@ class AxolotlInputConfig(
},
)
fix_untrained_tokens: int | list[int] | None = Field(
default=None,
json_schema_extra={
"description": (
"Token index or indices to adjust embedding weights to the mean of the other tokens. "
"This is useful when the model has untrained embeddings."
)
},
)
fix_untrained_tokens: int | list[int] | None = None
# INTERNALS - document for now, generally not set externally
is_preprocess: bool | None = None
@@ -1020,26 +992,6 @@ class AxolotlInputConfig(
return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
return None
@model_validator(mode="before")
@classmethod
def warn_peft_trainable_token_to_fix_untrained(cls, data):
if (
peft_trainable_token_indices := data.get("peft_trainable_token_indices")
) and (fix_untrained_tokens := data.get("fix_untrained_tokens")):
if isinstance(fix_untrained_tokens, int):
fix_untrained_tokens = (fix_untrained_tokens,)
if isinstance(peft_trainable_token_indices, int):
peft_trainable_token_indices = (peft_trainable_token_indices,)
for untrained_token_id in fix_untrained_tokens:
if untrained_token_id not in peft_trainable_token_indices:
LOG.warning_once(
f"Token {untrained_token_id} is fixed via `fix_untrained_tokens`, yet not in `peft_trainable_token_indices: ` list. "
"Please add it, otherwise the token won't be trained on."
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate GPU capabilities with the configured options"""
@@ -1313,14 +1265,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data["dataset_processes"] = get_default_process_count()
return data
@model_validator(mode="before")
@classmethod
def check_deduplication_with_streaming(cls, data):
if data.get("dataset_exact_deduplication") and (
data.get("streaming") or data.get("pretraining_dataset")
):
raise NotImplementedError(
"dataset_exact_deduplication is not available for streaming datasets. "
)
return data

View File

@@ -59,21 +59,16 @@ class ModelInputConfig(BaseModel):
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
tokenizer_save_jinja_files: bool | None = Field(
default=True, # match the default behavior from transformers
json_schema_extra={
"description": "Whether to save jinja files for tokenizer, transformers default is True"
},
)
trust_remote_code: bool | None = Field(
default=None,
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
experimental_skip_move_to_device: bool | None = Field(
default=True,
default=None,
json_schema_extra={
"description": "Don't move the model to the device before sharding. Set to `false` to revert to legacy behavior."
"description": "Don't move the model to the device before sharding. "
"This is an experimental feature that may be included in the future as the default."
},
)

View File

@@ -90,16 +90,6 @@ class LoraConfig(BaseModel):
"description": "How to initialize LoRA weights. Default to True which is MS original implementation."
},
)
peft_trainable_token_indices: list[int] | dict[str, list[int]] | None = Field(
default=None,
json_schema_extra={
"description": (
"A list of token indices to fine-tune on the `embed_tokens` layer.\n"
"Otherwise, a dict mapping an embedding layer name to its trainable token indices.\n"
"See https://huggingface.co/docs/peft/v0.17.0/en/developer_guides/lora#efficiently-train-tokens-alongside-lora"
)
},
)
qlora_sharded_model_loading: bool | None = Field(
default=False,

View File

@@ -60,20 +60,6 @@ class DatasetValidationMixin:
raise ValueError("either datasets or pretraining_dataset is required")
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_streaming_deprecation(cls, data):
# TODO(djsaunde): remove this check + implement change for 0.13.0 release
if data.get("pretraining_dataset") and not data.get("streaming"):
LOG.warning(
"Setting `pretraining_dataset` without explicitly setting `streaming: "
"true` is deprecated. In a future release, streaming will not be "
"automatically enabled when using pretraining_dataset. Please "
"explicitly set `streaming: true` in your configuration to maintain "
"current behavior."
)
return data
@model_validator(mode="before")
@classmethod
def check_push_ds_auth(cls, data):
@@ -354,30 +340,6 @@ class TrainingValidationMixin:
)
return data
@model_validator(mode="before")
@classmethod
def check_multipack_buffer_size(cls, data):
if data.get("pretrain_multipack_buffer_size") and not data.get(
"streaming_multipack_buffer_size"
):
LOG.warning(
"`pretrain_multipack_buffer_size` is deprecated in v0.13.0, will be "
"removed in v0.14.0. Use `streaming_multipack_buffer_size` instead."
)
data["streaming_multipack_buffer_size"] = data[
"pretrain_multipack_buffer_size"
]
del data["pretrain_multipack_buffer_size"]
elif data.get("pretrain_multipack_buffer_size") and data.get(
"streaming_multipack_buffer_size"
):
raise ValueError(
"pretrain_multipack_buffer_size is deprecated, use "
"streaming_multipack_buffer_size; both are set, please remove the "
"deprecated pretrain_multipack_buffer_size setting"
)
return data
@model_validator(mode="after")
def check_fft_possible_bad_config(self):
if (
@@ -1114,46 +1076,20 @@ class PretrainingValidationMixin:
@model_validator(mode="before")
@classmethod
def check_pretraining_w_val_set_size(cls, data):
if data.get("pretraining_dataset") and data.get("val_set_size"):
raise ValueError(
"val_set_size is not supported with pretraining_dataset. "
"Use test_datasets to specify evaluation datasets for pretraining."
)
return data
@model_validator(mode="before")
@classmethod
def check_streaming_w_val_set_size(cls, data):
if data.get("streaming") and data.get("val_set_size"):
raise ValueError(
"val_set_size is not supported with streaming datasets. "
"Use test_datasets to specify evaluation datasets when streaming is enabled."
)
return data
@model_validator(mode="before")
@classmethod
def check_streaming_w_max_steps(cls, data):
if data.get("streaming") and not data.get("max_steps"):
raise ValueError(
"max_steps must be set when using streaming datasets. "
"Trainer cannot infer dataset length for iterable datasets."
)
return data
@model_validator(mode="before")
@classmethod
def check_streaming_w_multiple_datasets(cls, data):
if (
data.get("streaming")
and data.get("sample_packing")
and data.get("datasets")
and len(data.get("datasets")) > 1
):
raise NotImplementedError(
"Sample packing with multiple streaming datasets is not yet supported"
)
def check_pretraining_split_batches_accelerate(cls, data):
# alternatively set ACCELERATE_SPLIT_BATCHES=False
if data.get("streaming"):
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
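A standalone restatement of the defaulting logic above, for readers who want to trace it outside the validator (a sketch; it distinguishes None from an explicit False the same way the code does):

def default_accelerator_config(data: dict) -> dict:
    """Mirror of the streaming defaults applied by the validator above."""
    if data.get("streaming"):
        acc = data.get("accelerator_config") or {}
        if acc.get("split_batches") is None:
            acc["split_batches"] = False
        if acc.get("dispatch_batches") is None:
            acc["dispatch_batches"] = False
        data["accelerator_config"] = acc
    return data

assert default_accelerator_config({"streaming": True})["accelerator_config"] == {
    "split_batches": False,
    "dispatch_batches": False,
}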

View File

@@ -475,9 +475,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = max(
1, len(data_loader) * cfg.micro_batch_size // cfg.batch_size
)
data_loader_len = len(data_loader) * cfg.micro_batch_size // cfg.batch_size
LOG.debug(f"data_loader_len: {data_loader_len}")
# FIXME: is there a bug here somewhere? the total num steps depends
# on the agreed on value for sample_packing_eff_est
@@ -549,13 +547,6 @@ def setup_deepspeed_env(cfg, stage=None):
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
device_count = torch.cuda.device_count()
if device_count == 1:
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("LOCAL_RANK", "0")
os.environ.setdefault("MASTER_ADDR", "0.0.0.0") # nosec B104
os.environ.setdefault("MASTER_PORT", "29500")
# NOTE(djsaunde): The distributed state cannot be initialized prior to the
# ACCELERATE_USE_DEEPSPEED assignment, but it must be initialized some time prior
# to model load.

View File

@@ -25,7 +25,7 @@ def min_cfg(temp_dir):
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": True,
"chat_template": "qwen3",
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,

View File

@@ -1,73 +0,0 @@
"""E2E tests for streaming dataset functionality"""
# pylint: disable=duplicate-code
import pytest
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard
class TestStreamingDatasets:
"""Test case for streaming datasets"""
@pytest.mark.parametrize(
"sample_packing",
[True, False],
)
def test_streaming_dataset(self, temp_dir, sample_packing):
"""Test streaming datasets"""
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"flash_attention": True,
"sequence_len": 1024,
"sample_packing": sample_packing,
"pretrain_multipack_attn": sample_packing,
"streaming_multipack_buffer_size": 10000,
"dataset_processes": 1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
# Streaming config
"streaming": True,
"max_steps": 3,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"val_set_size": 0.0,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"use_tensorboard": True,
"save_first_step": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
# Verify training actually happened by checking loss decrease
check_tensorboard(
temp_dir + "/runs",
"train/train_loss",
3.0,
"Train Loss (%s) is too high",
)

View File

@@ -1,63 +0,0 @@
"""
e2e test for saving the tokenizer
"""
from unittest.mock import patch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
def test_tokenizer_no_save_jinja_files(temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_first_step": False,
"fp16": False,
"tokenizer_save_jinja_files": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
with patch("axolotl.train.execute_training"):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = f.read()
assert "chat_template" in tokenizer_config

View File

@@ -3,6 +3,7 @@
import unittest
from axolotl.monkeypatch.transformers.trainer_loss_calc import (
check_evaluation_loop_is_fsdp2_patchable,
check_evaluation_loop_is_patchable,
check_maybe_log_save_evaluate_is_patchable,
)
@@ -19,6 +20,7 @@ class TestTrainerLossCalc(unittest.TestCase):
the patched code changes upstream.
"""
assert check_evaluation_loop_is_patchable()
assert check_evaluation_loop_is_fsdp2_patchable()
assert check_maybe_log_save_evaluate_is_patchable()

View File

@@ -6,7 +6,7 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.data import encode_pretraining, md5
from tests.hf_offline_utils import enable_hf_offline
@@ -39,7 +39,7 @@ class TestEncodePretraining(unittest.TestCase):
"hello, hello",
]
}
result = encode_streaming(examples, self.tokenizer, self.max_tokens)
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
self.assertEqual(len(result["input_ids"]), 3)

View File

@@ -1,11 +1,16 @@
"""Module for testing dataset sequence packing"""
import unittest
from pathlib import Path
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy
from axolotl.prompters import AlpacaPrompter
from axolotl.train import setup_model_and_trainer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
@@ -30,6 +35,43 @@ class TestPacking(unittest.TestCase):
}
)
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
raw_dataset = load_dataset(
"json",
data_files=str(Path(__file__).parent / "fixtures/alpaca/alpaca.json"),
)["train"]
dataset = Dataset.from_list(list(TokenizedPromptDataset(strat, raw_dataset)))
constant_len_dataset = ConstantLengthDataset(
self.tokenizer,
[dataset],
seq_length=2048,
)
packed_dataset = Dataset.from_list(list(constant_len_dataset))
example = packed_dataset[0]
next_bos_index = (
example["input_ids"][1:].index(self.tokenizer.bos_token_id) + 1
) # add one since we sliced
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but each subsequent sequence does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 2
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
@with_temp_dir
def test_lora_packing(self, temp_dir):
cfg = DictDefault(

View File

@@ -9,7 +9,7 @@ import torch
from datasets import IterableDataset
from torch.utils.data import DataLoader
from axolotl.utils.data import get_dataset_wrapper, wrap_streaming_dataset
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault
@@ -77,11 +77,14 @@ class TestPretrainingPacking:
)
original_bsz = cfg.micro_batch_size
train_dataset = wrap_streaming_dataset(
train_dataset = wrap_pretraining_dataset(
dataset,
tokenizer_huggyllama,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
)
trainer_loader = DataLoader(

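The renamed wrapper keeps the same call shape as before. Judging from the mocks in the streaming tests below, the wrapping pattern amounts to mapping a packing/encoding function over the iterable dataset in large buffered batches so packed rows come out pre-collated; a rough sketch of that pattern with the Hugging Face datasets API (names hypothetical, the real wrap_pretraining_dataset may differ):

def wrap_dataset_sketch(dataset, encode_fn, buffer_size: int):
    # Buffered batched map: the encode function sees buffer_size examples at
    # a time and can emit fully packed rows.
    return dataset.map(encode_fn, batched=True, batch_size=buffer_size)
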
View File

@@ -1,238 +0,0 @@
"""Test streaming configuration and data loading functionality."""
import unittest
from unittest.mock import Mock, patch
from datasets import IterableDataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.data.sft import (
_prepare_streaming_dataset,
prepare_datasets,
)
from axolotl.utils.config import validate_config
class TestStreamingConfig(unittest.TestCase):
"""Test streaming configuration and deprecation handling."""
def test_streaming_multipack_buffer_size_deprecation(self):
"""Test that pretrain_multipack_buffer_size is properly deprecated."""
# Test with old config name
cfg_old = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"pretrain_multipack_buffer_size": 5000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
validated_cfg = validate_config(cfg_old)
self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0])
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 5000)
self.assertIsNone(
getattr(validated_cfg, "pretrain_multipack_buffer_size", None)
)
def test_streaming_multipack_buffer_size_new(self):
"""Test that new streaming_multipack_buffer_size works correctly."""
cfg_new = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"streaming_multipack_buffer_size": 7000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
validated_cfg = validate_config(cfg_new)
self.assertEqual(validated_cfg.streaming_multipack_buffer_size, 7000)
def test_both_buffer_sizes_raises_error(self):
"""Test that having both old and new buffer size configs raises an error."""
cfg_both = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"pretrain_multipack_buffer_size": 5000,
"streaming_multipack_buffer_size": 7000,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"learning_rate": 0.0001,
}
)
with self.assertRaises(ValueError) as cm:
validate_config(cfg_both)
self.assertIn("both are set", str(cm.exception))
class TestStreamingDatasetPreparation(unittest.TestCase):
"""Test dataset preparation with streaming configuration."""
def setUp(self):
self.tokenizer = Mock()
self.tokenizer.pad_token_id = 0
self.tokenizer.eos_token_id = 1
@patch("axolotl.utils.data.sft._prepare_streaming_dataset")
def test_prepare_datasets_with_streaming_true(self, mock_prepare_streaming):
"""Test that streaming=True triggers streaming dataset preparation."""
cfg = DictDefault(
{
"streaming": True,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
}
)
mock_prepare_streaming.return_value = (Mock(), None, 100, [])
prepare_datasets(cfg, self.tokenizer)
mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
@patch("axolotl.utils.data.sft._prepare_streaming_dataset")
def test_prepare_datasets_with_pretraining_dataset(self, mock_prepare_streaming):
"""Test that pretraining_dataset triggers streaming dataset preparation."""
cfg = DictDefault(
{
"pretraining_dataset": "test/dataset",
}
)
mock_prepare_streaming.return_value = (Mock(), None, 100, [])
prepare_datasets(cfg, self.tokenizer)
mock_prepare_streaming.assert_called_once_with(cfg, self.tokenizer, None)
@patch("axolotl.utils.data.sft._prepare_standard_dataset")
def test_prepare_datasets_without_streaming(self, mock_prepare_standard):
"""Test that without streaming, standard dataset preparation is used."""
cfg = DictDefault(
{
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
}
)
mock_prepare_standard.return_value = (Mock(), None, 100, [])
prepare_datasets(cfg, self.tokenizer)
mock_prepare_standard.assert_called_once_with(cfg, self.tokenizer, None)
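
Taken together, the three dispatch tests imply a simple branch at the top of prepare_datasets: streaming mode and pretraining datasets share the streaming path, and everything else takes the standard one. A hypothetical sketch matching the call signature the mocks assert:

def prepare_datasets(cfg, tokenizer, processor=None):
    # Streaming SFT and pretraining both go through the streaming path
    if cfg.streaming or cfg.pretraining_dataset:
        return _prepare_streaming_dataset(cfg, tokenizer, processor)
    return _prepare_standard_dataset(cfg, tokenizer, processor)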
class TestStreamingWithSamplePacking(unittest.TestCase):
"""Test streaming dataset preparation with sample packing."""
def setUp(self):
self.tokenizer = Mock()
self.tokenizer.pad_token_id = 0
self.tokenizer.eos_token_id = 1
@patch("axolotl.utils.data.sft._load_streaming_dataset")
def test_streaming_sft_with_sample_packing_sets_split(self, mock_load_streaming):
"""Test that streaming SFT with sample_packing sets default split."""
cfg = DictDefault(
{
"streaming": True,
"sample_packing": True,
"datasets": [{"path": "test/dataset", "type": "alpaca"}],
"sequence_len": 256,
"micro_batch_size": 1,
}
)
mock_load_streaming.return_value = Mock(spec=IterableDataset)
with patch("axolotl.utils.data.sft._load_and_prepare_datasets"):
_prepare_streaming_dataset(cfg, self.tokenizer, None)
# Check that the dataset config has split set to 'train'
call_args = mock_load_streaming.call_args
dataset_config = call_args[0][0]
self.assertEqual(dataset_config.split, "train")
def test_multipack_attn_forced_true_for_sft(self):
"""Test that multipack_attn is forced to True for SFT with sample packing."""
from axolotl.utils.data.streaming import wrap_streaming_dataset
cfg = DictDefault(
{
"sample_packing": True,
"pretrain_multipack_attn": False, # Should be overridden for SFT
"pretraining_dataset": None, # This makes it SFT
"sequence_len": 256,
"micro_batch_size": 1,
"streaming_multipack_buffer_size": 1000,
"seed": 42,
}
)
mock_dataset = Mock()
mock_dataset.features = None # For streaming datasets
mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator
mock_dataset.map = Mock(return_value=mock_dataset)
mock_ds_wrapper = Mock()
with patch(
"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
) as mock_collator:
with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
wrap_streaming_dataset(
mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
)
# Check that multipack_attn=True was used in the collator
mock_collator.assert_called_once()
call_kwargs = mock_collator.call_args[1]
self.assertTrue(call_kwargs["multipack_attn"])
def test_multipack_attn_respects_config_for_pretraining(self):
"""Test that multipack_attn respects config for pretraining datasets."""
from axolotl.utils.data.streaming import wrap_streaming_dataset
cfg = DictDefault(
{
"sample_packing": True,
"pretrain_multipack_attn": False, # Should be respected for pretraining
"pretraining_dataset": "test/dataset", # This makes it pretraining
"sequence_len": 256,
"micro_batch_size": 1,
"streaming_multipack_buffer_size": 1000,
"seed": 42,
}
)
mock_dataset = Mock()
mock_dataset.features = None # For streaming datasets
mock_dataset.__iter__ = Mock(return_value=iter([])) # Empty iterator
mock_dataset.map = Mock(return_value=mock_dataset)
mock_ds_wrapper = Mock()
with patch(
"axolotl.utils.data.streaming.PretrainingBatchSamplerDataCollatorForSeq2Seq"
) as mock_collator:
with patch("axolotl.utils.data.streaming.encode_packed_streaming"):
wrap_streaming_dataset(
mock_dataset, self.tokenizer, cfg, mock_ds_wrapper
)
# Check that multipack_attn=False was used (respecting config)
mock_collator.assert_called_once()
call_kwargs = mock_collator.call_args[1]
self.assertFalse(call_kwargs["multipack_attn"])
if __name__ == "__main__":
unittest.main()
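
The last two tests jointly fix the override rule for multipack_attn: streaming SFT (no pretraining_dataset) forces it on so packed sequences keep separate attention, while genuine pretraining respects pretrain_multipack_attn from the config. Condensed to a sketch (hypothetical; the real logic in axolotl.utils.data.streaming may differ):

# multipack_attn is forced for SFT, config-driven for pretraining
is_pretraining = bool(cfg.pretraining_dataset)
multipack_attn = cfg.pretrain_multipack_attn if is_pretraining else True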