Compare commits

...

8 Commits

Author SHA1 Message Date
NanoCode012
e37a768960 feat: add baseten to lmeval 2025-08-29 18:02:26 +07:00
Wing Lian
6afba3871d Add support for PyTorch 2.8.0 (#3106)
* Add support for PyTorch 2.8.0

* loosen triton requirements

* handle torch 2.8.0 in setup.py

* fix versions

* no vllm for torch 2.8.0

* remove comment

Co-authored-by: NanoCode012 <nano@axolotl.ai>

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-08-28 09:10:40 -04:00
Dan Saunders
dc338c3b0e Update .coderabbit.yaml (#3109) [skip ci]
Oops, should be false.
2025-08-27 09:50:52 -04:00
salman
d0d2fc5606 Tokens per second logging [skip-e2e] (#3072) 2025-08-27 09:10:14 +01:00
Wing Lian
e1131e9619 make always skip_move_to_device default as true (#3084) 2025-08-26 09:30:22 -04:00
Wing Lian
c4c4b90638 add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json (#3093)
* add tokenizer_save_jinja_files to keep legacy behavior of including chat template in tokenizer_config.json

* fix test import
2025-08-26 09:30:04 -04:00
Wing Lian
0e9945e3b9 deploy training jobs to baseten w truss in axolotl cli (#3086) [skip ci]
* deploy training jobs to baseten w truss in axolotl cli

* cleanup
2025-08-26 09:29:50 -04:00
NanoCode012
0de254a0d0 feat: add gemma3_text attention handling for lora kernels (#3103) 2025-08-26 16:47:26 +07:00
31 changed files with 534 additions and 47 deletions

View File

@@ -12,6 +12,6 @@ reviews:
auto_review:
enabled: true
drafts: false
auto_incremental_review: true
auto_incremental_review: false
chat:
auto_reply: true

View File

@@ -36,6 +36,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -110,6 +115,11 @@ jobs:
python_version: "3.11"
pytorch: 2.7.1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -169,6 +179,12 @@ jobs:
pytorch: 2.7.1
axolotl_extras: vllm
is_latest: true
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
is_latest:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -33,13 +33,6 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
@@ -47,6 +40,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:

View File

@@ -55,7 +55,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -130,7 +130,7 @@ jobs:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0", "2.7.0", "2.7.1"]
pytorch_version: ["2.6.0", "2.7.1", "2.8.0"]
timeout-minutes: 20
steps:
@@ -240,7 +240,7 @@ jobs:
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
dockerfile: "Dockerfile-uv.jinja"
@@ -298,6 +298,12 @@ jobs:
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
- cuda: 128
cuda_version: 12.8.1
python_version: "3.11"
pytorch: 2.8.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -334,10 +340,10 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.6.0
pytorch: 2.7.1
num_gpus: 1
axolotl_extras:
steps:

View File

@@ -0,0 +1,10 @@
provider: baseten
project_name:
secrets:
- HF_TOKEN
- WANDB_API_KEY
gpu: h100
gpu_count: 8
node_count: 1

View File

@@ -40,7 +40,7 @@
"%%capture\n",
"# This step can take ~5-10 minutes to install dependencies\n",
"!pip install --no-build-isolation axolotl[flash-attn]>=0.9.1\n",
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8\""
"!pip install \"cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5\""
]
},
{

View File

@@ -2,8 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3

View File

@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
print(
UNINSTALL_PREFIX
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"'
+ f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"'
)

View File

@@ -64,7 +64,9 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 7):
if (major, minor) >= (2, 8):
pass
elif (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.30")

View File

@@ -7,6 +7,8 @@ from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
from axolotl.cli.cloud.baseten import BasetenCloud
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
@@ -38,8 +40,15 @@ def do_cli_train(
cwd=None,
**kwargs,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
local_dirs = {}
@@ -58,8 +67,16 @@ def do_cli_lm_eval(
cloud_config: Path | str,
config: Path | str,
) -> None:
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
cloud_cfg: DictDefault = load_cloud_cfg(cloud_config)
provider = cloud_cfg.provider or "modal"
cloud: Cloud | None
if provider == "modal":
cloud = ModalCloud(cloud_cfg)
elif provider == "baseten":
cloud = BasetenCloud(cloud_cfg.to_dict())
else:
raise ValueError(f"Unsupported cloud provider: {provider}")
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)

View File

@@ -0,0 +1,68 @@
"""Baseten Cloud CLI"""
import shutil
import subprocess # nosec B404
import tempfile
from os.path import dirname
from typing import Literal
import yaml
from axolotl.cli.cloud.base import Cloud
class BasetenCloud(Cloud):
"""Baseten Cloud Axolotl CLI"""
def __init__(self, config: dict):
self.config = config
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
raise NotImplementedError(
"Separate preprocess function for Baseten is not "
"implemented and will happen during hte train step."
)
def train(
self,
config_yaml: str,
launcher: Literal["accelerate", "torchrun", "python"] = "accelerate",
launcher_args: list[str] | None = None,
local_dirs: dict[str, str] | None = None, # pylint: disable=unused-argument
**kwargs,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
config["launcher"] = launcher
config["launcher_args"] = launcher_args
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/train.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(dirname(__file__) + "/template/run.sh", tmp_dir + "/run.sh")
shutil.copyfile(
dirname(__file__) + "/template/train_sft.py", tmp_dir + "/train_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "train_sft.py"], cwd=tmp_dir, check=False
)
def lm_eval(
self,
config_yaml: str,
):
with tempfile.TemporaryDirectory() as tmp_dir:
config = self.config.copy()
with open(tmp_dir + "/cloud.yaml", "w", encoding="utf-8") as cloud_fout:
yaml.dump(config, cloud_fout)
with open(tmp_dir + "/eval.yaml", "w", encoding="utf-8") as config_fout:
config_fout.write(config_yaml)
shutil.copyfile(
dirname(__file__) + "/template/eval.sh", tmp_dir + "/eval.sh"
)
shutil.copyfile(
dirname(__file__) + "/template/eval_sft.py", tmp_dir + "/eval_sft.py"
)
subprocess.run( # nosec B603 B607
["truss", "train", "push", "eval_sft.py"], cwd=tmp_dir, check=False
)

View File

@@ -0,0 +1,8 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl lm-eval eval.yaml

View File

@@ -0,0 +1,81 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = (
1 # int(cloud_config.get("gpu_count", 1)) # only single GPU supported at the moment
)
node_count = (
1 # int(cloud_config.get("node_count", 1)) # only single node support for lmeval
)
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
# launcher = cloud_config.get("launcher", "accelerate")
# launcher_args = cloud_config.get("launcher_args", [])
script_name = "eval.sh"
# launcher_args_str = ""
# if launcher_args:
# launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
# "AXOLOTL_LAUNCHER": launcher,
# "AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -0,0 +1,9 @@
#!/bin/bash
set -eux
export NCCL_SOCKET_IFNAME="^docker0,lo"
export NCCL_IB_DISABLE=0
export NCCL_TIMEOUT=1800000
axolotl preprocess train.yaml
axolotl train train.yaml --launcher ${AXOLOTL_LAUNCHER} ${AXOLOTL_LAUNCHER_ARGS}

View File

@@ -0,0 +1,77 @@
"""
Baseten Training Script for Axolotl
"""
# pylint: skip-file
import yaml
from truss.base import truss_config
# Import necessary classes from the Baseten Training SDK
from truss_train import definitions
cloud_config = yaml.safe_load(open("cloud.yaml", "r"))
gpu = cloud_config.get("gpu", "h100")
gpu_count = int(cloud_config.get("gpu_count", 1))
node_count = int(cloud_config.get("node_count", 1))
project_name = cloud_config.get("project_name", "axolotl-project") or "axolotl-project"
secrets = cloud_config.get("secrets", [])
launcher = cloud_config.get("launcher", "accelerate")
launcher_args = cloud_config.get("launcher_args", [])
script_name = "run.sh"
launcher_args_str = ""
if launcher_args:
launcher_args_str = "-- " + " ".join(launcher_args)
# 1. Define a base image for your training job
# must use torch 2.7.0 for vllm
BASE_IMAGE = "axolotlai/axolotl:main-py3.11-cu126-2.7.1"
# 2. Define the Runtime Environment for the Training Job
# This includes start commands and environment variables.a
# Secrets from the baseten workspace like API keys are referenced using
# `SecretReference`.
env_vars = {
"AXOLOTL_LAUNCHER": launcher,
"AXOLOTL_LAUNCHER_ARGS": launcher_args_str,
}
for secret_name in secrets:
env_vars[secret_name] = definitions.SecretReference(name=secret_name)
training_runtime = definitions.Runtime(
start_commands=[ # Example: list of commands to run your training script
f"/bin/sh -c 'chmod +x ./{script_name} && ./{script_name}'"
],
environment_variables=env_vars,
cache_config=definitions.CacheConfig(
enabled=True,
),
checkpointing_config=definitions.CheckpointingConfig(
enabled=True,
),
)
# 3. Define the Compute Resources for the Training Job
training_compute = definitions.Compute(
node_count=node_count,
accelerator=truss_config.AcceleratorSpec(
accelerator=truss_config.Accelerator.H100,
count=gpu_count,
),
)
# 4. Define the Training Job
# This brings together the image, compute, and runtime configurations.
my_training_job = definitions.TrainingJob(
image=definitions.Image(base_image=BASE_IMAGE),
compute=training_compute,
runtime=training_runtime,
)
# This config will be pushed using the Truss CLI.
# The association of the job to the project happens at the time of push.
first_project_with_job = definitions.TrainingProject(
name=project_name, job=my_training_job
)

View File

@@ -43,7 +43,10 @@ def do_merge_lora(*, cfg: DictDefault) -> None:
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
tokenizer.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
if processor:
processor.save_pretrained(str(Path(cfg.output_dir) / "merged"))

View File

@@ -84,5 +84,6 @@ def do_quantize(
str(Path(output_dir) / "quantized"),
safe_serialization=False,
progressbar=True,
save_jinja_files=cfg.tokenizer_save_jinja_files,
)
LOG.info(f"Quantized model saved to: {str(Path(output_dir) / 'quantized')}...")

View File

@@ -24,9 +24,7 @@ from pathlib import Path
from typing import Any
import torch
from transformers import (
TrainerCallback,
)
from transformers import TrainerCallback
from transformers.trainer_pt_utils import AcceleratorConfig
from axolotl.integrations.base import PluginManager
@@ -38,6 +36,7 @@ from axolotl.utils.callbacks import (
SaveModelOnFirstStepCallback,
)
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.callbacks.tokens_per_second import TokensPerSecondCallback
from axolotl.utils.distributed import build_parallelism_config
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
@@ -146,6 +145,12 @@ class TrainerBuilderBase(abc.ABC):
profiler_steps_start=self.cfg.profiler_steps_start,
)
)
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
)
)
return callbacks
@@ -512,6 +517,7 @@ class TrainerBuilderBase(abc.ABC):
self.cfg.eval_batch_size
)
training_args_kwargs["include_tkps"] = self.cfg.include_tkps
training_args_kwargs["max_steps"] = self.cfg.max_steps or total_num_steps or -1
training_args_kwargs["num_train_epochs"] = self.cfg.num_epochs

View File

@@ -404,6 +404,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
# if the trainer has the `axolotl_cfg` property, set it
if hasattr(trainer, "axolotl_cfg"):
trainer.axolotl_cfg = self.cfg
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)

View File

@@ -42,6 +42,7 @@ from axolotl.core.trainers.utils import (
)
from axolotl.utils import get_not_null
from axolotl.utils.bench import get_gpu_memory_usage
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -63,6 +64,15 @@ class AxolotlTrainer(
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
_axolotl_cfg: DictDefault | None = None
@property
def axolotl_cfg(self):
return self._axolotl_cfg
@axolotl_cfg.setter
def axolotl_cfg(self, cfg):
self._axolotl_cfg = cfg
def __init__(
self,
@@ -78,7 +88,6 @@ class AxolotlTrainer(
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
@@ -327,6 +336,17 @@ class AxolotlTrainer(
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
# track number of tokens for tokens per second calculation
if self.args.include_tkps:
inputs_key = "labels" if "labels" in inputs else "input_ids"
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum()
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum()
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
@@ -526,9 +546,6 @@ class AxolotlTrainer(
super().create_accelerator_and_postprocess()
# now we need to put parallelism_config back on the PartialState since we rely on that info in other places
# PartialState().parallelism_config = self.accelerator.state.parallelism_config
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
@@ -576,12 +593,19 @@ class AxolotlTrainer(
# Add memory usage
try:
active, allocated, reserved = get_gpu_memory_usage()
logs["memory/max_mem_active(gib)"] = round(active, 2)
logs["memory/max_mem_allocated(gib)"] = round(allocated, 2)
logs["memory/device_mem_reserved(gib)"] = round(reserved, 2)
logs["memory/max_active (GiB)"] = round(active, 2)
logs["memory/max_allocated (GiB)"] = round(allocated, 2)
logs["memory/device_reserved (GiB)"] = round(reserved, 2)
except (ValueError, TypeError, FileNotFoundError):
pass
if self.args.include_tkps and train_eval == "train":
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
@@ -657,6 +681,11 @@ class AxolotlTrainer(
LOG.info(
"Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`"
)
self.data_collator.tokenizer.save_pretrained(output_dir)
save_jinja_files = True
if self.axolotl_cfg:
save_jinja_files = self.axolotl_cfg.tokenizer_save_jinja_files
self.data_collator.tokenizer.save_pretrained(
output_dir, save_jinja_files=save_jinja_files
)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

View File

@@ -49,6 +49,12 @@ class AxolotlTrainingMixins:
default=False,
metadata={"help": "Use real batches for efficient training."},
)
include_tkps: bool = field(
default=True,
metadata={
"help": "Whether to include tokens per second in the training metrics."
},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},

View File

@@ -19,7 +19,7 @@ python scripts/cutcrossentropy_install.py | sh
- If you are installing from pip
```bash
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"
```
## Usage

View File

@@ -35,7 +35,7 @@ LOG = get_logger(__name__)
_CCE_INSTALL_MESSAGE = (
"Please install Axolotl's fork of cut_cross_entropy with transformers support using "
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@0ee9ee8"`'
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@c6a32c5"`'
)

View File

@@ -149,6 +149,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
return MistralAttention
if model_type == "gemma3_text":
from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention
return Gemma3Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"

View File

@@ -416,7 +416,9 @@ def save_initial_configs(
# Pre-save the tokenizer and model configs
LOG.info(f"Pre-saving tokenizer to {cfg.output_dir}...")
tokenizer.save_pretrained(str(output_dir))
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
if hasattr(model, "config"):
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
model.config.save_pretrained(str(output_dir))
@@ -592,6 +594,9 @@ def train(
# Save the trained model and cleanup
save_trained_model(cfg, trainer, model, safe_serialization)
tokenizer.save_pretrained(
str(Path(cfg.output_dir)), save_jinja_files=cfg.tokenizer_save_jinja_files
)
create_model_card(cfg, trainer)
if not cfg.use_ray:
cleanup_distributed()

View File

@@ -60,13 +60,14 @@ def gpu_memory_usage_all(device=0):
active = torch.cuda.memory_stats().get("active_bytes.all.peak", 0) / 1024.0**3
allocated = torch.cuda.max_memory_allocated(device) / 1024.0**3
reserved = torch.cuda.max_memory_reserved(device) / 1024.0**3
torch.cuda.reset_peak_memory_stats(device)
return active, allocated, reserved
def mps_memory_usage_all():
usage = torch.mps.current_allocated_memory() / 1024.0**3
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
return usage, reserved - usage, 0
active = torch.mps.current_allocated_memory() / 1024.0**3
allocated = torch.mps.driver_allocated_memory() / 1024.0**3
return active, allocated, 0
def npu_memory_usage_all(device=0):

View File

@@ -0,0 +1,62 @@
"""A callback for calculating tokens per second during training."""
import time
import torch
from transformers import (
TrainerCallback,
TrainerControl,
TrainerState,
TrainingArguments,
)
class TokensPerSecondCallback(TrainerCallback):
"""
A callback to measure and log tokens per second during training.
"""
def __init__(self, tensor_parallel_size, context_parallel_size):
super().__init__()
self.step_time = 0.0
self.start_time = 0.0
self.non_data_parallel_size = 1
if tensor_parallel_size is not None:
self.non_data_parallel_size *= tensor_parallel_size
if context_parallel_size is not None:
self.non_data_parallel_size *= context_parallel_size
def on_step_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
self.start_time = time.perf_counter()
state.last_tokens_per_second = torch.zeros(1)
def on_step_end(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = state.num_tokens.clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs=None,
**kwargs,
): # pylint: disable=unused-argument
# after logging, clear the running metrics
state.last_tokens_per_second.zero_()
state.num_tokens = 0

View File

@@ -77,7 +77,7 @@ def resolve_dtype(cfg):
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False
if cfg.bf16:
if cfg.bf16 and cfg.fp16 is not False:
cfg.fp16 = True
cfg.bf16 = False
else:

View File

@@ -830,10 +830,15 @@ class AxolotlInputConfig(
include_tokens_per_second: bool | None = Field(
default=None,
json_schema_extra={
"description": "bool of whether to include tokens trainer per second in the training metrics. This iterates over the entire dataset once, so it takes some time."
"description": "bool of whether to report tokens per second at the end of training. This is not supported with pre-training datasets."
},
)
include_tkps: bool | None = Field(
default=None,
json_schema_extra={
"description": "bool of whether to report tokens per second during training by measuring throughput of non-padding tokens."
},
)
neftune_noise_alpha: float | None = Field(
default=None,
json_schema_extra={

View File

@@ -59,16 +59,21 @@ class ModelInputConfig(BaseModel):
processor_type: str | None = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
)
tokenizer_save_jinja_files: bool | None = Field(
default=True, # match the default behavior from transformers
json_schema_extra={
"description": "Whether to save jinja files for tokenizer, transformers default is True"
},
)
trust_remote_code: bool | None = Field(
default=None,
json_schema_extra={"description": "Trust remote code for untrusted source"},
)
experimental_skip_move_to_device: bool | None = Field(
default=None,
default=True,
json_schema_extra={
"description": "Don't move the model to the device before sharding. "
"This is an experimental feature that may be included in the future as the default."
"description": "Don't move the model to the device before sharding. Set to `false` to revert to legacy behavior."
},
)

View File

@@ -0,0 +1,63 @@
"""
e2e test for saving the tokenizer
"""
from unittest.mock import patch
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists
def test_tokenizer_no_save_jinja_files(temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"tokenizer_type": "AutoTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.02,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"chat_template": "chatml",
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch_fused",
"lr_scheduler": "cosine",
"max_steps": 5,
"save_first_step": False,
"fp16": False,
"tokenizer_save_jinja_files": False,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)
with patch("axolotl.train.execute_training"):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
with open(f"{temp_dir}/tokenizer_config.json", "r", encoding="utf-8") as f:
tokenizer_config = f.read()
assert "chat_template" in tokenizer_config