Compare commits

..

8 Commits

Author SHA1 Message Date
NanoCode012
e37a768960 feat: add baseten to lmeval 2025-08-29 18:02:26 +07:00
Wing Lian
6afba3871d Add support for PyTorch 2.8.0 (#3106)
* Add support for PyTorch 2.8.0

* loosen triton requirements

* handle torch 2.8.0 in setup.py

* fix versions

* no vllm for torch 2.8.0

* remove comment

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-28 09:10:40 -04:00
Dan Saunders
dc338c3b0e Update .coderabbit.yaml (#3109) [skip ci]
Oops, should be false.
2025-08-27 09:50:52 -04:00
salman
d0d2fc5606 Tokens per second logging [skip-e2e] (#3072) 2025-08-27 09:10:14 +01:00
Wing Lian
e1131e9619 make always skip_move_to_device default as true (#3084) 2025-08-26 09:30:22 -04:00
Wing Lian
c4c4b90638 add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
2025-08-26 09:30:04 -04:00
Wing Lian
0e9945e3b9 deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci]
* deploy training jobs to baseten w truss in axolotl cli

* cleanup
2025-08-26 09:29:50 -04:00
NanoCode012
0de254a0d0 feat: add gemma3_text attention handling for lora kernels (#3103) 2025-08-26 16:47:26 +07:00
36 changed files with 557 additions and 349 deletions

View File

@@ -12,6 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
auto_incremental_review: false
chat:
auto_reply: true

View File

@@ -36,6 +36,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -110,6 +115,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -169,6 +179,12 @@ jobs:
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -33,13 +33,6 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -240,7 +240,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -298,6 +298,12 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -334,10 +340,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -0,0 +1,10 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
]
},
{

View File

@@ -2,8 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
)

View File

@@ -64,7 +64,9 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 7):
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")
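For context, a minimal sketch of the version gate this hunk adds (an assumed simplification of parse_requirements, not the full setup.py logic): torch 2.8.x leaves the requirements untouched, while 2.7.x keeps the existing xformers pin swap.

```python
# Sketch only: assumed simplification of the gate added to parse_requirements.
def gate_xformers(install_requires: list[str], major: int, minor: int, patch: int,
                  xformers_version: str) -> None:
    if (major, minor) >= (2, 8):
        pass  # torch 2.8+: no xformers handling here
    elif (major, minor) >= (2, 7):
        install_requires.remove(xformers_version)
        if patch == 0:
            install_requires.append("xformers==0.0.30")

reqs = ["xformers>=0.0.23.post1"]
gate_xformers(reqs, 2, 8, 0, "xformers>=0.0.23.post1")
print(reqs)  # unchanged for torch 2.8.0
```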

View File

@@ -7,6 +7,8 @@ from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
@@ -38,8 +40,15 @@ def do_cli_train(
cwd=None,
**kwargs,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}
@@ -58,8 +67,16 @@ def do_cli_lm_eval(
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)
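In effect, do_cli_train and do_cli_lm_eval now share the same provider dispatch; a rough sketch of exercising it directly (module paths taken from the imports above, not a documented entry point):

```python
# Rough sketch: provider routing as added to do_cli_train / do_cli_lm_eval.
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault

cloud_cfg = DictDefault({"provider": "baseten", "gpu": "h100", "gpu_count": 8})
provider = cloud_cfg.provider or "modal"  # a missing provider key falls back to Modal
if provider == "modal":
    cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
    cloud = BasetenCloud(cloud_cfg.to_dict())
else:
    raise ValueError(f"Unsupported cloud provider: {provider}")
```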

View File

@@ -0,0 +1,68 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)
def lm_eval(
self,
config_yaml: str,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/eval.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(
dirname(__file__) + "/template/eval.sh", tmp_dir + "/eval.sh"
)
shutil.copyfile(
dirname(__file__) + "/template/eval_sft.py", tmp_dir + "/eval_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "eval_sft.py"], cwd=tmp_dir, check=False
)
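A rough sketch of driving this class directly (hypothetical values; the CLI dispatch above normally constructs it from cloud.yaml). train() stages cloud.yaml, train.yaml, run.sh, and train_sft.py into a temp directory and shells out to the Truss CLI:

```python
# Hypothetical direct use of BasetenCloud; requires the `truss` CLI to be
# installed and authenticated against a Baseten workspace.
from axolotl.cli.cloud.baseten import BasetenCloud

cloud = BasetenCloud({"provider": "baseten", "gpu": "h100", "gpu_count": 8,
                      "node_count": 1, "secrets": ["HF_TOKEN", "WANDB_API_KEY"]})
with open("train.yaml", "r", encoding="utf-8") as f:  # any axolotl training config
    cloud.train(f.read(), launcher="accelerate")      # ends with `truss train push train_sft.py`
```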

View File

@@ -0,0 +1,8 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl lm-eval eval.yaml

View File

@@ -0,0 +1,81 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = (
1 # int(cloud_config.get("gpu_count", 1)) # only single GPU supported at the moment
)
node_count = (
1 # int(cloud_config.get("node_count", 1)) # only single node support for lmeval
)
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
# launcher = cloud_config.get("launcher", "accelerate")
# launcher_args = cloud_config.get("launcher_args", [])
script_name = "eval.sh"
# launcher_args_str = ""
# if launcher_args:
# launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.x for vllm (no vllm support on torch 2.8.0 yet)
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
# "AXOLOTL_LAUNCHER": launcher,
# "AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -0,0 +1,77 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.x for vllm (no vllm support on torch 2.8.0 yet)
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -84,5 +84,6 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -24,9 +24,7 @@ from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers import TrainerCallback
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -38,6 +36,7 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -146,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
@@ -512,6 +517,7 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -404,6 +404,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)

View File

@@ -42,6 +42,7 @@ from axolotl.core.trainers.utils import (
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -63,6 +64,15 @@ class AxolotlTrainer(
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
_axolotl_cfg: DictDefault | None = None
@property
def axolotl_cfg(self):
return self._axolotl_cfg
@axolotl_cfg.setter
def axolotl_cfg(self, cfg):
self._axolotl_cfg = cfg
def __init__(
self,
@@ -78,7 +88,6 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
@@ -327,6 +336,17 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -526,9 +546,6 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -576,12 +593,19 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
@@ -657,6 +681,11 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -49,6 +49,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
)

View File

@@ -149,6 +149,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return MistralAttention
if model_type == "gemma3_text":
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
return Gemma3Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"

View File

@@ -416,7 +416,9 @@ def save_initial_configs(
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
@@ -592,6 +594,9 @@ def train(
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()

View File

@@ -60,13 +60,14 @@ def gpu_memory_usage_all(device=0):
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
torch.cuda.reset_peak_memory_stats(device)
return active, allocated, reserved
def mps_memory_usage_all():
usage = torch.mps.current_allocated_memory() / 1024.0**3
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
return usage, reserved - usage, 0
active = torch.mps.current_allocated_memory() / 1024.0**3
allocated = torch.mps.driver_allocated_memory() / 1024.0**3
return active, allocated, 0
def npu_memory_usage_all(device=0):

View File

@@ -0,0 +1,62 @@
"""A callback for calculating tokens per second during training."""
import time
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
class TokensPerSecondCallback(TrainerCallback):
"""
A callback to measure and log tokens per second during training.
"""
def __init__(self, tensor_parallel_size, context_parallel_size):
super().__init__()
self.step_time = 0.0
self.start_time = 0.0
self.non_data_parallel_size = 1
if tensor_parallel_size is not None:
self.non_data_parallel_size *= tensor_parallel_size
if context_parallel_size is not None:
self.non_data_parallel_size *= context_parallel_size
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
self.start_time = time.perf_counter()
state.last_tokens_per_second = torch.zeros(1)
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = state.num_tokens.clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs=None,
**kwargs,
): # pylint: disable=unused-argument
# after logging, clear the running metrics
state.last_tokens_per_second.zero_()
state.num_tokens = 0
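As a worked example of the arithmetic above (made-up numbers, not output from a real run): the callback divides each rank's non-padding token count by the product of the tensor- and context-parallel sizes so duplicated tokens are not double-counted, then divides by the measured step time; the trainer's log() further divides by logging_steps to obtain a moving average. The whole path is gated by the new include_tkps flag.

```python
# Illustrative numbers only; mirrors on_step_end plus the division in log().
tensor_parallel_size, context_parallel_size = 2, 1
non_data_parallel_size = tensor_parallel_size * context_parallel_size  # 2

num_tokens_this_step = 8192   # (inputs["labels"] != -100).sum() on this rank
step_time = 0.5               # seconds between on_step_begin and on_step_end
logging_steps = 1

tokens_per_device = num_tokens_this_step / non_data_parallel_size   # 4096.0
last_tokens_per_second = tokens_per_device / step_time              # 8192.0
print(round(last_tokens_per_second / logging_steps, 2))             # 8192.0
```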

View File

@@ -77,7 +77,7 @@ def resolve_dtype(cfg):
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
if cfg.bf16 and cfg.fp16 is not False:
cfg.fp16 = True
cfg.bf16 = False
else:
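The behavioral change here is narrow: on MPS, bf16: true used to be rewritten to fp16: true even when the user had explicitly set fp16: false. A small sketch of the new condition (simplified; the diff does not show the else branch):

```python
# Sketch of the changed condition; the real resolve_dtype does more than this.
def mps_fallback(bf16, fp16):
    if bf16 and fp16 is not False:   # an explicit fp16=False now wins
        return {"bf16": False, "fp16": True}
    return {"bf16": bf16, "fp16": fp16}

print(mps_fallback(True, None))    # {'bf16': False, 'fp16': True}
print(mps_fallback(True, False))   # left unchanged in this sketch
```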

View File

@@ -26,7 +26,6 @@ from axolotl.utils.data.shared import (
save_preprocessed_dataset,
try_load_from_hub,
)
from axolotl.utils.data.streaming import wrap_streaming_sft_dataset
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
handle_long_seq_in_dataset,
@@ -74,7 +73,7 @@ def _prepare_standard_dataset(
tokenizer: PreTrainedTokenizer,
processor: ProcessorMixin | None,
preprocess_iterable: bool,
) -> tuple[Dataset | IterableDataset, Dataset | None, int, list[Prompter | None]]:
) -> tuple[Dataset, Dataset | None, int, list[Prompter | None]]:
"""Prepare standard (non-pretraining) datasets."""
def _load_datasets():
@@ -119,14 +118,7 @@ def _prepare_standard_dataset(
)
# Calculate total number of training steps
# For streaming datasets, we must use max_steps
if isinstance(train_dataset, IterableDataset):
if not cfg.max_steps:
raise ValueError(
"When using streaming datasets, you must set max_steps in your config"
)
total_num_steps = cfg.max_steps
elif cfg.max_steps:
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
@@ -350,18 +342,14 @@ def _load_raw_datasets(
dataset = handle_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
else:
dataset = handle_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
# Skip packing processing for streaming datasets - they handle it differently
if cfg.sample_packing and not isinstance(dataset, IterableDataset):
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
# Skip saving for streaming datasets as they can't be cached
if not isinstance(dataset, IterableDataset):
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
# Save the prepared dataset
dataset_hash = generate_dataset_hash_from_config(
cfg, datasets_configs, tokenizer.name_or_path
)
save_preprocessed_dataset(cfg, dataset, dataset_hash, split)
return dataset, prompters
@@ -377,10 +365,8 @@ def _load_and_process_single_dataset(
) -> tuple[Dataset | IterableDataset, Prompter | None]:
"""Load and process a single dataset based on the passed config."""
# Load the dataset
# Use streaming if enabled in config or if using iterable preprocessing
use_streaming = cfg.streaming or preprocess_iterable
dataset = load_dataset_with_config(
dataset_config, cfg.hf_use_auth_token, streaming=use_streaming
dataset_config, cfg.hf_use_auth_token, streaming=preprocess_iterable
)
# Parse dataset type
@@ -405,63 +391,16 @@ def _load_and_process_single_dataset(
num_shards=dataset_config.shards, index=shards_idx
)
# For streaming datasets, we need to handle tokenization differently
if isinstance(dataset, IterableDataset):
# Use pretraining's approach for multipack streaming
if cfg.sample_packing:
# Create the dataset wrapper function once
def ds_wrapper_fn(dataset=None):
wrapped_dataset, prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
return wrapped_dataset, prompter
# Use pretraining wrapper for efficient streaming SFT with packing
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
dataset_wrapper = wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed,
buffer_size=cfg.pretrain_multipack_buffer_size,
)
else:
# Use regular streaming wrapper
dataset_wrapper = wrap_streaming_sft_dataset(
dataset,
tokenizer,
cfg,
dataset_config,
d_base_type,
d_prompt_style,
processor,
max_tokens=cfg.sequence_len,
buffer_size=10_000,
)
# For streaming, we don't have a specific prompter
dataset_prompter = None
else:
# Apply dataset wrapper for regular datasets
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# Apply dataset wrapper
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
return dataset_wrapper, dataset_prompter

View File

@@ -524,9 +524,7 @@ def generate_dataset_hash_from_config(
return str(md5(config_str))
def merge_datasets(
datasets: list[Dataset | IterableDataset], cfg: DictDefault
) -> Dataset | IterableDataset:
def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
"""Merge multiple datasets into one with optional shuffling.
Args:
@@ -536,41 +534,6 @@ def merge_datasets(
Returns:
Merged dataset.
"""
# Check if we're dealing with streaming datasets
if any(isinstance(ds, IterableDataset) for ds in datasets):
# All datasets must be streaming for merging
if not all(isinstance(ds, IterableDataset) for ds in datasets):
raise ValueError(
"Cannot mix streaming and non-streaming datasets. "
"Either all datasets must be streaming or none."
)
if len(datasets) == 1:
ds = datasets[0]
# Streaming datasets handle shuffling differently
if cfg.shuffle_merged_datasets and not cfg.curriculum_sampling:
return ds.shuffle(seed=cfg.seed, buffer_size=10_000)
return ds
# Merge streaming datasets
LOG.info("Merging streaming datasets...")
from datasets import interleave_datasets
# For streaming, we interleave datasets instead of concatenating
merged_dataset = interleave_datasets(datasets)
if cfg.shuffle_merged_datasets:
LOG.debug("Shuffling merged streaming datasets...")
if cfg.curriculum_sampling:
LOG.warning(
"Shuffling merged datasets with curriculum sampling is not recommended. "
"This will randomize the order of samples."
)
merged_dataset = merged_dataset.shuffle(seed=cfg.seed, buffer_size=10_000)
return merged_dataset
# Original logic for non-streaming datasets
if len(datasets) == 1:
ds = datasets[0]

View File

@@ -1,150 +0,0 @@
"""Utilities for handling streaming datasets."""
import functools
from collections import defaultdict
from typing import Any, Dict, List
import numpy as np
from datasets import Dataset, IterableDataset
from torch.utils.data import RandomSampler
from transformers import PreTrainedTokenizerBase
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import add_position_ids
LOG = get_logger(__name__)
def wrap_streaming_sft_dataset(
dataset: IterableDataset,
tokenizer: PreTrainedTokenizerBase,
cfg,
dataset_config,
d_base_type: str,
d_prompt_style: str | None,
processor: Any | None,
max_tokens: int = 2048,
buffer_size: int = 10_000,
) -> IterableDataset:
"""
Wrap a streaming SFT dataset with tokenization and optional packing.
This is similar to wrap_pretraining_dataset but for SFT datasets.
Args:
dataset: The streaming dataset to wrap
tokenizer: Tokenizer to use
cfg: Configuration object
dataset_config: Dataset configuration
d_base_type: Base dataset type
d_prompt_style: Prompt style
processor: Optional processor for multimodal
max_tokens: Maximum sequence length
buffer_size: Buffer size for shuffling
Returns:
Wrapped streaming dataset ready for training
"""
# Import here to avoid circular imports
from axolotl.utils.data.wrappers import get_dataset_wrapper
# Apply shuffling if configured
if cfg.shuffle_merged_datasets:
LOG.info(f"Shuffling streaming dataset with buffer_size={buffer_size}")
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
# For streaming datasets, we need to get column names from the first sample
remove_columns = []
for first_row in dataset:
remove_columns = list(first_row.keys())
break
# Reset dataset after peeking
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=cfg.seed, buffer_size=buffer_size)
# Define the encoding function - always add position_ids for compatibility
if cfg.sample_packing:
# For sample packing, we need to handle position_ids
def encode_streaming_packed(examples: Dict[str, List]) -> Dict[str, List]:
"""Encode examples for streaming with sample packing."""
# Convert the batch dict to a temporary Dataset for processing
temp_dataset = Dataset.from_dict(examples)
# Apply the dataset wrapper to tokenize
wrapped_dataset, _ = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=temp_dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# Convert to dict for processing
result = {}
if hasattr(wrapped_dataset, "to_dict"):
result = wrapped_dataset.to_dict()
else:
for key in wrapped_dataset.column_names:
result[key] = wrapped_dataset[key]
# Add position_ids using the existing function
result = add_position_ids(result)
# For multipack attention, we may need to drop attention_mask
if cfg.pretrain_multipack_attn and "attention_mask" in result:
del result["attention_mask"]
return result
encode_fn = encode_streaming_packed
else:
# Regular encoding without packing - still add position_ids for compatibility
def encode_streaming(examples: Dict[str, List]) -> Dict[str, List]:
"""Encode examples for streaming."""
# Convert the batch dict to a temporary Dataset for processing
temp_dataset = Dataset.from_dict(examples)
# Apply the dataset wrapper to tokenize
wrapped_dataset, _ = get_dataset_wrapper(
dataset_config=dataset_config,
tokenizer=tokenizer,
cfg=cfg,
dataset_base_type=d_base_type,
dataset=temp_dataset,
dataset_prompt_style=d_prompt_style,
processor=processor,
)
# Convert to dict format
result = {}
if hasattr(wrapped_dataset, "to_dict"):
result = wrapped_dataset.to_dict()
else:
for key in wrapped_dataset.column_names:
result[key] = wrapped_dataset[key]
# Add position_ids even without packing for compatibility
result = add_position_ids(result)
return result
encode_fn = encode_streaming
# Map the encoding function over the streaming dataset
dataset = dataset.map(
encode_fn,
batched=True,
batch_size=buffer_size,
remove_columns=remove_columns,
)
# Set format for PyTorch
dataset = dataset.with_format("torch")
return dataset

View File

@@ -178,8 +178,8 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def handle_long_seq_in_dataset(
dataset: Dataset | IterableDataset, sequence_len: int, cfg: DictDefault
) -> Dataset | IterableDataset:
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
@@ -190,14 +190,7 @@ def handle_long_seq_in_dataset(
Returns:
Filtered dataset with long sequences removed.
"""
# Streaming datasets don't support filtering the same way
if isinstance(dataset, IterableDataset):
LOG.info(
"Streaming dataset detected - long sequence filtering will be done on-the-fly"
)
return dataset
if not hasattr(dataset, "column_names") or "input_ids" not in dataset.column_names:
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
"expected for reward modeling."

View File

@@ -244,12 +244,6 @@ class AxolotlInputConfig(
dataloader_num_workers: int | None = None
dataloader_prefetch_factor: int | None = None
dataloader_drop_last: bool | None = None
streaming: bool | None = Field(
default=None,
json_schema_extra={
"description": "Enable streaming mode for training datasets to reduce memory usage and enable training on datasets larger than memory"
},
)
accelerator_config: dict[str, Any] | None = None
@@ -836,10 +830,15 @@ class AxolotlInputConfig(
include_tokens_per_second: bool | None = Field(
default=None,
json_schema_extra={
"description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
"description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
},
)
include_tkps: bool | None = Field(
default=None,
json_schema_extra={
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
},
)
neftune_noise_alpha: float | None = Field(
default=None,
json_schema_extra={

View File

@@ -59,16 +59,21 @@ class ModelInputConfig(BaseModel):
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
tokenizer_save_jinja_files: bool | None = Field(
default=True, # match the default behavior from transformers
json_schema_extra={
"description": "Whether to save jinja files for tokenizer, transformers default is True"
},
)
trust_remote_code: bool | None = Field(
default=None,
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
experimental_skip_move_to_device: bool | None = Field(
default=None,
default=True,
json_schema_extra={
"description": "Don't move the model to the device before sharding. "
"This is an experimental feature that may be included in the future as the default."
"description": "Don't move the model to the device before sharding. Set to `false` to revert to legacy behavior."
},
)
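For reference, a minimal sketch of how these two schema changes surface in a user config, written the way the e2e test below builds its config (values other than the two new keys are placeholders):

```python
# Sketch only: opting back into the legacy behaviors whose defaults changed here.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "HuggingFaceTB/SmolLM2-135M",  # any base model
        "tokenizer_save_jinja_files": False,  # keep chat_template inside tokenizer_config.json
        "experimental_skip_move_to_device": False,  # move the model to device before sharding (legacy)
    }
)
```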

View File

@@ -1074,24 +1074,6 @@ class PretrainingValidationMixin:
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@classmethod
def check_pretraining_split_batches_accelerate(cls, data):
# alternatively set ACCELERATE_SPLIT_BATCHES=False
if data.get("streaming"):
accelerator_config = data.get("accelerator_config", {})
if not accelerator_config:
data["accelerator_config"] = {
"split_batches": False,
"dispatch_batches": False,
}
else:
if accelerator_config.get("split_batches") is None:
data["accelerator_config"]["split_batches"] = False
if accelerator_config.get("dispatch_batches") is None:
data["accelerator_config"]["dispatch_batches"] = False
return data
class ModelCompatibilityValidationMixin:
"""Validation methods for specific model compatibility."""

View File

@@ -0,0 +1,63 @@
"""
e2e test for saving the tokenizer
"""
from unittest.mock import patch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
def test_tokenizer_no_save_jinja_files(temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_first_step": False,
"fp16": False,
"tokenizer_save_jinja_files": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
with patch("axolotl.train.execute_training"):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = f.read()
assert "chat_template" in tokenizer_config