Compare commits

..

1 Commits

Author SHA1 Message Date
coderabbitai[bot]
0fccbadb79 📝 Add docstrings to 202512-raise_on_drop
Docstrings generation was requested by @kallewoof.

* https://github.com/axolotl-ai-cloud/axolotl/pull/3321#issuecomment-3668489902

The following files were modified:

* `src/axolotl/utils/data/utils.py`
* `src/axolotl/utils/trainer.py`
2025-12-18 05:49:01 +00:00
40 changed files with 107 additions and 826 deletions

View File

@@ -12,9 +12,6 @@ jobs:
build-deploy:
runs-on: ubuntu-latest
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository
uses: actions/checkout@v4
- name: Set up Quarto

View File

@@ -11,7 +11,6 @@ on:
- '_quarto.yml'
- docs/scripts/generate_config_docs.py
- src/axolotl/utils/schemas/**.py
- .github/workflows/preview-docs.yml
permissions:
checks: write
@@ -28,10 +27,6 @@ jobs:
runs-on: ubuntu-latest
if: ${{ !github.event.pull_request.draft }}
steps:
- name: cleanup node
run: |
sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL
- name: Check out repository
uses: actions/checkout@v4
with:

View File

@@ -114,7 +114,7 @@ jobs:
- name: Run tests
run: |
df -h
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
df -h
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
df -h
@@ -196,7 +196,7 @@ jobs:
- name: Run tests
run: |
pytest -v --durations=10 -n4 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ --ignore=tests/monkeypatch/ tests/ --cov=axolotl --cov-report=xml
pytest -v --durations=10 tests/monkeypatch/ --cov=axolotl --cov-append --cov-report=xml
pytest -v --durations=10 tests/cli/

View File

@@ -1,48 +0,0 @@
# Finetune GLM4.5 with Axolotl
[UNSTABLE]
```bash
# LoRA SFT (4xH200 @ 84GB/GPU)
axolotl train examples/glm45/glm4.5-lora-fsdp2.yaml
# FFT SFT (4xH200)
# Checkpointing error on backward pass
# Without checkpointing => OOM
axolotl train examples/glm45/glm4.5-fft-fsdp2.yaml
```
## Dataset
In addition to the normal OpenAI Messages format, GLM4.5 supports an extra parameter for thinking in the assistant section.
```json
{
"role": "assistant",
"reasoning_content": "...", // or have <think>...</think> in `content`
"content": "...",
}
```
Note:
- The role name for tools in this template is `tool`.
- You will see this Axolotl WARNING. This is expected, as the template does not use EOS.
```bash
EOS token '<|endoftext|>' not found in chat_template. Please check if your template/EOS token is correct.
```
- Make sure you set the below extra attributes if needed
```yaml
datasets:
- path: ...
type: chat_template
message_property_mappings:
role: role
content: content
# tool_calls: tool_calls # uncomment if using tools
# reasoning_content: reasoning_content # uncomment if have reasoning
# Uncomment if training on tool role (you would rarely if ever need this)
# eot_tokens:
# - <|observation|>
```

View File

@@ -1,59 +0,0 @@
base_model: zai-org/GLM-4.5-Air
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
# gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
state_dict_type: SHARDED_STATE_DICT
reshard_after_forward: true
activation_checkpointing: true

View File

@@ -1,74 +0,0 @@
base_model: zai-org/GLM-4.5-Air
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
plugins:
- axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
experimental_skip_move_to_device: true # prevent OOM by NOT putting model to GPU before sharding
datasets:
- path: winglian/pirate-ultrachat-10k
type: chat_template
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: lora
lora_model_dir:
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
# gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: Glm4MoeDecoderLayer
state_dict_type: SHARDED_STATE_DICT
reshard_after_forward: true
# activation_checkpointing: false

View File

@@ -32,10 +32,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -28,10 +28,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -29,10 +29,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -28,10 +28,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 2
micro_batch_size: 1
num_epochs: 1

View File

@@ -41,10 +41,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

View File

@@ -41,10 +41,6 @@ wandb_watch:
wandb_name:
wandb_log_model:
trackio_project_name:
trackio_run_name:
trackio_space_id:
gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1

View File

@@ -29,6 +29,7 @@ flex_attention: true
flex_attn_compile_kwargs:
dynamic: false
mode: max-autotune-no-cudagraphs
save_strategy: no
torch_compile: true
wandb_project:

View File

@@ -1,70 +0,0 @@
base_model: Qwen/Qwen2.5-0.5B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Use random initialization for fair comparison
reinit_weights: true
load_in_8bit: false
load_in_4bit: false
strict: false
# Pretraining dataset
pretraining_dataset:
- path: allenai/c4
name: en
type: pretrain
split: train
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/compare-adamw-pretrain
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project: dist_muon
wandb_entity:
wandb_watch:
wandb_name: adamw
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 1
max_steps: 305
# AdamW optimizer settings (standard LR for AdamW)
optimizer: adamw_torch_fused
learning_rate: 0.0002
weight_decay: 0.01
lr_scheduler: cosine
train_on_inputs: true
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: false
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
# Reproducibility
seed: 42
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_reshard_after_forward: true
special_tokens:

View File

@@ -1,70 +0,0 @@
base_model: Qwen/Qwen2.5-0.5B
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer
# Use random initialization for fair comparison
reinit_weights: true
load_in_8bit: false
load_in_4bit: false
strict: false
# Pretraining dataset
pretraining_dataset:
- path: allenai/c4
name: en
type: pretrain
split: train
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./outputs/compare-muon-pretrain
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project: dist_muon
wandb_entity:
wandb_watch:
wandb_name: muon
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 1
max_steps: 305
# Muon optimizer settings
optimizer: muon
learning_rate: 0.02
weight_decay: 0.01
lr_scheduler: cosine
train_on_inputs: true
group_by_length: false
bf16: auto
fp16: false
tf32: false
gradient_checkpointing: false
logging_steps: 1
flash_attention: true
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
# Reproducibility
seed: 42
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: false
fsdp_reshard_after_forward: true
special_tokens:

View File

@@ -20,16 +20,15 @@ deepspeed>=0.17.0
trl==0.25.0
hf_xet==1.2.0
kernels>=0.9.0
trackio>=0.13.0
typing_extensions>=4.14.0
trackio
optimum==1.16.2
hf_transfer
sentencepiece
gradio>=6.2.0,<7.0
gradio==5.49.1
modal==1.0.2
pydantic>=2.10.6,<2.12
pydantic>=2.10.6
addict
fire
PyYAML>=6.0
@@ -68,7 +67,8 @@ openenv-core==0.1.0
schedulefree==1.4.1
axolotl-contribs-lgpl==0.0.7
axolotl-contribs-mit==0.0.6
axolotl-contribs-mit==0.0.5
# telemetry
posthog==6.7.11

View File

@@ -26,7 +26,6 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.tee import prepare_debug_log
from axolotl.utils.trackio_ import setup_trackio_env_vars
from axolotl.utils.trainer import prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
@@ -228,7 +227,6 @@ def load_cfg(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"fp8": compute_supports_fp8(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
@@ -247,7 +245,6 @@ def load_cfg(
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
setup_trackio_env_vars(cfg)
plugin_set_cfg(cfg)
TELEMETRY_MANAGER.send_event(event_type="config-processed", properties=cfg)
@@ -262,11 +259,3 @@ def load_cfg(
)
return cfg
def compute_supports_fp8() -> bool:
    """Return whether the current CUDA device reports compute capability >= 9.0.

    Returns:
        bool: ``True`` when ``torch.cuda.get_device_capability()`` is at least
        ``(9, 0)``; ``False`` when CUDA is unavailable or the capability
        cannot be queried.
    """
    # CPU-only PyTorch builds raise AssertionError (not RuntimeError) from the
    # CUDA lazy initialisation, so check availability before querying.
    if not torch.cuda.is_available():
        return False
    try:
        compute_capability = torch.cuda.get_device_capability()
    except (RuntimeError, AssertionError):
        return False
    # Tuple comparison orders by major version first, then minor.
    return compute_capability >= (9, 0)

View File

@@ -288,8 +288,8 @@ def do_inference_gradio(
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.launch(
footer_links=["gradio", "settings"],
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),

View File

@@ -366,8 +366,8 @@ def launch_diffusion_gradio_ui(
outputs=[masked_preview, html_out],
)
demo.launch(
footer_links=["gradio", "settings"],
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),

View File

@@ -14,7 +14,6 @@ MOE_ARCH_BLOCK = {
"qwen3_moe": "Qwen3MoeSparseMoeBlock",
"qwen3_vl_moe": "Qwen3VLMoeTextSparseMoeBlock",
"deepseek_v2": "DeepseekV2MoE",
"glm4_moe": "Glm4MoeMoE",
"deepseek_v3": "DeepseekV3MoE",
"gpt_oss": "GptOssDecoderLayer",
"lfm2_moe": "Lfm2MoeSparseMoeBlock",

View File

@@ -35,7 +35,6 @@ from axolotl.utils import (
is_comet_available,
is_mlflow_available,
is_opentelemetry_available,
is_trackio_available,
)
from axolotl.utils.callbacks import (
GCCallback,
@@ -148,14 +147,6 @@ class TrainerBuilderBase(abc.ABC):
callbacks.append(
SaveAxolotlConfigtoCometCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_trackio and is_trackio_available():
from axolotl.utils.callbacks.trackio_ import (
SaveAxolotlConfigtoTrackioCallback,
)
callbacks.append(
SaveAxolotlConfigtoTrackioCallback(self.cfg.axolotl_config_path)
)
if self.cfg.use_otel_metrics and is_opentelemetry_available():
from axolotl.utils.callbacks.opentelemetry import (
OpenTelemetryMetricsCallback,
@@ -290,22 +281,11 @@ class TrainerBuilderBase(abc.ABC):
adam_kwargs["eps"] = training_args_kwargs.get("adam_epsilon")
if self.cfg.optimizer == "muon":
_, device_mesh = build_parallelism_config(self.cfg)
if device_mesh is not None:
from axolotl.contribs.mit.muon.dist_muon import (
DistMuonOptimizerFactory,
)
optimizer_cls = DistMuonOptimizerFactory
optimizer_kwargs["device_mesh"] = device_mesh
else:
from axolotl.contribs.mit.muon import (
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
from axolotl.contribs.mit.muon import (
MuonOptimizerFactory,
)
optimizer_cls = MuonOptimizerFactory
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "dion":
from axolotl.contribs.mit.dion import (
@@ -443,8 +423,6 @@ class TrainerBuilderBase(abc.ABC):
report_to.append("tensorboard")
if self.cfg.use_comet:
report_to.append("comet_ml")
if self.cfg.use_trackio:
report_to.append("trackio")
training_args_kwargs["report_to"] = report_to
@@ -452,8 +430,6 @@ class TrainerBuilderBase(abc.ABC):
training_args_kwargs["run_name"] = self.cfg.wandb_name
elif self.cfg.use_mlflow:
training_args_kwargs["run_name"] = self.cfg.mlflow_run_name
elif self.cfg.use_trackio:
training_args_kwargs["run_name"] = self.cfg.trackio_run_name
else:
training_args_kwargs["run_name"] = None

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
import math
import os
from collections import defaultdict
from functools import partial, wraps
@@ -604,7 +603,6 @@ class AxolotlTrainer(
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
metric_ndigits = int(os.getenv("AXOLOTL_METRIC_NDIGITS", "5"))
for key, metric_data in self._stored_metrics[train_eval].items():
values = torch.tensor(metric_data["values"]) # type: ignore[arg-type]
@@ -615,18 +613,7 @@ class AxolotlTrainer(
raise NotImplementedError(
"Metric reduction must be one of [mean, min, max, sum]"
)
logs[key] = round(fn(values).item(), metric_ndigits)
if "loss" in logs:
try:
logs["ppl"] = round(math.exp(logs["loss"]), metric_ndigits)
except OverflowError:
logs["ppl"] = float("inf")
if "eval_loss" in logs:
try:
logs["eval_ppl"] = round(math.exp(logs["eval_loss"]), metric_ndigits)
except OverflowError:
logs["eval_ppl"] = float("inf")
logs[key] = round(fn(values).item(), 4)
if is_main_process():
# Add memory usage

View File

@@ -36,6 +36,4 @@ class DPOStrategy:
training_args_kwargs["dpo_norm_loss"] = cfg.dpo_norm_loss
if cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs["use_logits_to_keep"] = cfg.dpo_use_logits_to_keep
if cfg.dpo_use_liger_kernel is not None:
training_args_kwargs["use_liger_kernel"] = cfg.dpo_use_liger_kernel
return training_args_kwargs

View File

@@ -44,7 +44,6 @@ plugins:
- gemma3n_text
- glm
- glm4
- glm_moe
- glm4_moe
- glm4v
- glm4v_moe

View File

@@ -21,7 +21,7 @@ class DenseMixerPlugin(BasePlugin):
if cfg.dense_mixer:
if not importlib.util.find_spec("densemixer"):
raise RuntimeError(
"DenseMixer is not installed. Install it with `pip install densemixer`"
"DenseMixer is not installed. Install it with `pip install densemizer`"
)
from densemixer.patching import (

View File

@@ -37,7 +37,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"deepseek_v3",
"glm",
"glm4",
"glm4_moe",
"smollm3",
"granite",
"granitemoe",

View File

@@ -24,10 +24,6 @@ def is_opentelemetry_available():
)
def is_trackio_available():
    """Report whether an import spec for the ``trackio`` package can be found."""
    trackio_spec = importlib.util.find_spec("trackio")
    return trackio_spec is not None
def get_pytorch_version() -> tuple[int, int, int]:
"""
Get Pytorch version as a tuple of (major, minor, patch).

View File

@@ -1,44 +0,0 @@
"""Trackio module for trainer callbacks"""
from typing import TYPE_CHECKING
import trackio
from transformers import TrainerCallback, TrainerControl, TrainerState
from axolotl.utils.distributed import is_main_process
from axolotl.utils.environment import is_package_version_ge
from axolotl.utils.logging import get_logger
if TYPE_CHECKING:
from axolotl.core.training_args import AxolotlTrainingArguments
LOG = get_logger(__name__)
class SaveAxolotlConfigtoTrackioCallback(TrainerCallback):
"""Trainer callback that passes the Axolotl config path to `trackio.save` when training begins."""
def __init__(self, axolotl_config_path):
# Path handed to `trackio.save` in `on_train_begin`.
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: "AxolotlTrainingArguments",
state: TrainerState,
control: TrainerControl,
**kwargs,
):
# Save the config through `trackio.save` and return `control`.
# NOTE(review): indentation is not visible in this view; the save attempt
# appears to be guarded by `is_main_process()` — confirm against the file.
if is_main_process():
try:
# Saving is only attempted when `is_package_version_ge("trackio", "0.11.0")`
# holds; otherwise a warning is logged and `control` is returned early.
if not is_package_version_ge("trackio", "0.11.0"):
LOG.warning(
"Trackio version 0.11.0 or higher is required to save config files. "
"Please upgrade trackio: pip install --upgrade trackio"
)
return control
trackio.save(self.axolotl_config_path)
LOG.info("The Axolotl config has been saved to Trackio.")
# These three exception types are logged as warnings instead of propagating.
except (FileNotFoundError, ConnectionError, AttributeError) as err:
LOG.warning(f"Error while saving Axolotl config to Trackio: {err}")
return control

View File

@@ -180,18 +180,20 @@ def truncate_long_seq(sample, sequence_len=2048, min_sequence_len=2):
def handle_long_seq_in_dataset(
dataset: Dataset, sequence_len: int, cfg: DictDefault
) -> Dataset:
"""Remove sequences longer than configured maximum from dataset.
Args:
dataset: Dataset to filter.
sequence_len: Maximum length for sequences to keep
cfg: Dictionary mapping `axolotl` config keys to values.
"""
Remove or truncate sequences that exceed the configured maximum length from a dataset.
Parameters:
dataset (Dataset): Dataset to process; if it lacks an "input_ids" column or is streaming, it is returned unchanged.
sequence_len (int): Maximum allowed sequence length; sequences longer than this are either removed or truncated.
cfg (DictDefault): Configuration object with keys:
- excess_length_strategy: "drop", "truncate", or "raise" — determines how to handle overlong sequences.
- min_sample_len: minimum allowed sequence length (used when truncating or dropping).
- dataset_num_proc: number of processes to use for non-streaming datasets.
- is_preprocess: when true, bypasses cached preprocessing during filtering.
Returns:
Filtered dataset with long sequences handled according to the excess_length_strategy value:
'drop' (default) excludes any sequence longer than sequence_len
'truncate' truncates them down to sequence_len
'raise' raises a ValueError if any sequence was found that was longer than sequence_len
Dataset: The input dataset with sequences longer than `sequence_len` removed or truncated according to `cfg`.
"""
if (
hasattr(dataset, "column_names")
@@ -234,12 +236,7 @@ def handle_long_seq_in_dataset(
drop_long_kwargs = {}
if filter_map_kwargs:
action = (
"Checking Sequence Lengths"
if excess_length_strategy == "raise"
else "Dropping Long Sequences"
)
drop_long_kwargs["desc"] = f"{action} (>{sequence_len})"
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
if excess_length_strategy == "truncate":
process_fn = functools.partial(
@@ -269,4 +266,4 @@ def handle_long_seq_in_dataset(
)
LOG.warning(f"{action.title()} {dropped} samples from dataset")
return dataset
return dataset

View File

@@ -2,7 +2,6 @@
from typing import Annotated, Any, Literal
from accelerate.utils import is_fp8_available
from annotated_types import MinLen
from packaging import version
from pydantic import (
@@ -34,7 +33,6 @@ from axolotl.utils.schemas.integrations import (
MLFlowConfig,
OpenTelemetryConfig,
RayConfig,
TrackioConfig,
WandbConfig,
)
from axolotl.utils.schemas.internal import EnvCapabilities, GPUCapabilities
@@ -64,7 +62,6 @@ class AxolotlInputConfig(
WandbConfig,
MLFlowConfig,
CometConfig,
TrackioConfig,
OpenTelemetryConfig,
LISAConfig,
GradioConfig,
@@ -176,12 +173,6 @@ class AxolotlInputConfig(
dpo_use_logits_to_keep: bool | None = None
dpo_label_smoothing: float | None = None
dpo_norm_loss: bool | None = None
dpo_use_liger_kernel: bool | None = Field(
default=None,
json_schema_extra={"description": "Whether to use Liger kernel for DPO loss."},
)
dpo_padding_free: bool | None = None
dpo_generate_during_eval: bool | None = None
@@ -454,10 +445,10 @@ class AxolotlInputConfig(
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
},
)
excess_length_strategy: Literal["drop", "truncate", "raise"] | None = Field(
excess_length_strategy: Literal["drop", "truncate"] | None = Field(
default=None,
json_schema_extra={
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len; 'raise' raises a ValueError. Defaults to 'drop' for backward compatibility."
"description": "What to do when a tokenized row exceeds sequence_len. 'drop' removes the row; 'truncate' slices tensors to sequence_len. Defaults to 'drop' for backward compatibility."
},
)
eval_sequence_len: int | None = Field(
@@ -1101,16 +1092,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
)
return self
@model_validator(mode="after")
def check_fp8(self):
    """Reject fp8 requests that the GPU or the installed packages cannot serve.

    Raises:
        ValueError: if ``fp8`` is set but ``capabilities.fp8`` is falsy, or if
            both are set but ``is_fp8_available()`` returns a falsy value.

    Returns:
        The validated model instance, unchanged.
    """
    # Nothing to validate unless fp8 was actually requested.
    if not self.fp8:
        return self
    if not self.capabilities.fp8:
        raise ValueError("fp8 requested, but fp8 is not supported on this GPU")
    if not is_fp8_available():
        raise ValueError(
            "fp8 requested, but missing one of ms-amp, transformers-engine or torchao."
        )
    return self
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_sdpa_bf16(cls, data):

View File

@@ -200,23 +200,3 @@ class OpenTelemetryConfig(BaseModel):
"description": "Port for the Prometheus metrics HTTP server"
},
)
class TrackioConfig(BaseModel):
"""Trackio configuration subset"""
# Optional flag; stays None when not set in the config.
use_trackio: bool | None = None
# The string fields below are all optional (default None) and carry their
# user-facing description in `json_schema_extra`.
trackio_project_name: str | None = Field(
default=None,
json_schema_extra={"description": "Your trackio project name"},
)
trackio_run_name: str | None = Field(
default=None,
json_schema_extra={"description": "Set the name of your trackio run"},
)
trackio_space_id: str | None = Field(
default=None,
json_schema_extra={
"description": "Hugging Face Space ID to sync dashboard to (optional, runs locally if not provided)"
},
)

View File

@@ -751,19 +751,12 @@ class OptimizationValidationMixin:
@model_validator(mode="before")
@classmethod
def check_muon_deepspeed_fsdp(cls, data):
if data.get("optimizer") == "muon":
if data.get("deepspeed"):
raise ValueError(
"Muon optimizer is currently incompatible with DeepSpeed"
)
if data.get("fsdp") or data.get("fsdp_config"):
fsdp_version = data.get("fsdp_version")
if fsdp_version is None:
fsdp_version = data.get("fsdp_config", {}).get("fsdp_version", 1)
if str(fsdp_version) != "2":
raise ValueError(
"Muon optimizer is only compatible with FSDP2. Set fsdp_version: 2 to use Muon with FSDP."
)
if data.get("optimizer") == "muon" and (
data.get("deepspeed") or data.get("fsdp") or data.get("fsdp_config")
):
raise ValueError(
"Muon optimizer is currently incompatible with DeepSpeed and FSDP"
)
return data
@model_validator(mode="before")
@@ -847,6 +840,40 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
    """Hoist a deprecated ``fsdp_config.fsdp_version`` to the top-level field.

    When ``data["fsdp_config"]`` holds a truthy ``fsdp_version``, a deprecation
    warning is logged, the value is popped from ``fsdp_config`` and stored in
    ``data["fsdp_version"]``.

    Parameters:
        data: Raw config mapping handed to the "before" validator.

    Returns:
        The same ``data`` mapping.
    """
    nested_cfg = data.get("fsdp_config") or {}
    if not nested_cfg or not nested_cfg.get("fsdp_version"):
        return data

    LOG.warning(
        "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
        "Please configure `fsdp_version` as a top-level field."
    )
    data["fsdp_version"] = nested_cfg.pop("fsdp_version")
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Normalise deprecated ``fsdp_``-prefixed keys inside ``fsdp_config``.

    If any key of ``data["fsdp_config"]`` starts with ``fsdp_``, a deprecation
    warning is logged and ``data["fsdp_config"]`` is replaced by a new dict in
    which the leading ``fsdp_`` is removed from every such key except
    ``fsdp_version``, which is kept as-is.

    Parameters:
        data: Raw config mapping handed to the "before" validator.

    Returns:
        The same ``data`` mapping, possibly with ``fsdp_config`` replaced.
    """
    fsdp_config = data.get("fsdp_config")
    if not fsdp_config:
        return data

    if not any(key.startswith("fsdp_") for key in fsdp_config):
        return data

    LOG.warning_once(
        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
        "Please omit the `fsdp_` prefix from any fields in `fsdp_config`."
    )
    data["fsdp_config"] = {
        # removeprefix() strips only the leading marker; str.replace() would
        # also delete any later "fsdp_" occurrence inside the key.
        (key if key == "fsdp_version" else key.removeprefix("fsdp_")): value
        for key, value in fsdp_config.items()
    }
    return data
@model_validator(mode="after")
def check_fsdp_offload_w_8bit_optimizer(self):
if (
@@ -948,40 +975,6 @@ class OptimizationValidationMixin:
return data
@model_validator(mode="before")
@classmethod
def check_fsdp_version_in_fsdp_config(cls, data):
    """Hoist a deprecated ``fsdp_config.fsdp_version`` to the top-level field.

    When ``data["fsdp_config"]`` holds a truthy ``fsdp_version``, a deprecation
    warning is logged, the value is popped from ``fsdp_config`` and stored in
    ``data["fsdp_version"]``.

    Parameters:
        data: Raw config mapping handed to the "before" validator.

    Returns:
        The same ``data`` mapping.
    """
    nested_cfg = data.get("fsdp_config") or {}
    if not nested_cfg or not nested_cfg.get("fsdp_version"):
        return data

    LOG.warning(
        "Configuring `fsdp_version` in `fsdp_config` is deprecated. "
        "Please configure `fsdp_version` as a top-level field."
    )
    data["fsdp_version"] = nested_cfg.pop("fsdp_version")
    return data
@model_validator(mode="before")
@classmethod
def check_fsdp_config_kwargs_prefix(cls, data):
    """Normalise deprecated ``fsdp_``-prefixed keys inside ``fsdp_config``.

    If any key of ``data["fsdp_config"]`` starts with ``fsdp_``, a deprecation
    warning is logged and ``data["fsdp_config"]`` is replaced by a new dict in
    which the leading ``fsdp_`` is removed from every such key except
    ``fsdp_version``, which is kept as-is.

    Parameters:
        data: Raw config mapping handed to the "before" validator.

    Returns:
        The same ``data`` mapping, possibly with ``fsdp_config`` replaced.
    """
    fsdp_config = data.get("fsdp_config")
    if not fsdp_config:
        return data

    if not any(key.startswith("fsdp_") for key in fsdp_config):
        return data

    LOG.warning_once(
        "Configuring FSDP fields with the `fsdp_` prefix is deprecated. "
        "Please omit the `fsdp_` prefix from any fields in `fsdp_config`."
    )
    data["fsdp_config"] = {
        # removeprefix() strips only the leading marker; str.replace() would
        # also delete any later "fsdp_" occurrence inside the key.
        (key if key == "fsdp_version" else key.removeprefix("fsdp_")): value
        for key, value in fsdp_config.items()
    }
    return data
class SystemValidationMixin:
"""Validation methods related to system and hardware configuration."""

View File

@@ -1,17 +0,0 @@
"""Module for trackio utilities"""
import os
from axolotl.utils.dict import DictDefault
def setup_trackio_env_vars(cfg: DictDefault):
    """Export ``trackio_*`` config strings as environment variables.

    Every config key starting with ``trackio_`` whose value is a non-empty
    string is written to ``os.environ`` under the upper-cased key name. When
    ``cfg.trackio_project_name`` is non-empty, ``cfg.use_trackio`` is set to
    ``True``.

    Parameters:
        cfg: Config object exposing ``keys()``, ``get()`` and attribute access.
    """
    for name in cfg.keys():
        if not name.startswith("trackio_"):
            continue
        setting = cfg.get(name, "")
        # Only non-empty strings are exported; other value types are ignored.
        if isinstance(setting, str) and setting:
            os.environ[name.upper()] = setting

    project_name = cfg.trackio_project_name
    if project_name and len(project_name) > 0:
        cfg.use_trackio = True

View File

@@ -201,19 +201,33 @@ def add_pose_position_ids(
def add_length(sample):
    """
    Record the number of input tokens of a sample under its "length" key.

    Parameters:
        sample: Mapping with an "input_ids" entry that supports ``len()``.

    Returns:
        The same ``sample`` object, updated in place with "length".
    """
    token_count = len(sample["input_ids"])
    sample["length"] = token_count
    return sample
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2, raise_on_drop=False):
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
If raise_on_drop is set, the code raises a ValueError if a sample is
encountered that is too long and would have been dropped.
Return whether a sample (single or batched) should be kept based on sequence length constraints.
Determines if each sequence's length falls within [min_sequence_len, sequence_len]. Supports a single example (list[int]) or a batch (list[list[int]]). If the sample's "input_ids" is empty, the sample is treated as dropped. When raise_on_drop is True, encountering any sequence longer than sequence_len raises a ValueError.
Parameters:
sample (dict): A mapping containing "input_ids" with either a single sequence or a batch of sequences.
sequence_len (int): Maximum allowed sequence length (inclusive).
min_sequence_len (int): Minimum allowed sequence length (inclusive).
raise_on_drop (bool): If True, raise ValueError when a sequence exceeds sequence_len.
Returns:
bool or list[bool]: For a single example, returns True if its length is within the bounds, False otherwise. For a batch, returns a list of booleans indicating which sequences should be kept.
"""
min_sequence_len = min_sequence_len or 2
@@ -726,4 +740,4 @@ def setup_trainer(
trainer_builder.train_dataset = train_dataset
trainer_builder.eval_dataset = eval_dataset
return trainer_builder.build(total_num_steps)
return trainer_builder.build(total_num_steps)

View File

@@ -474,8 +474,10 @@ def rand_reward_func(prompts, completions) -> list[float]:
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import MuonOptimizerFactory
from axolotl.contribs.mit.muon.muon import Muon
from axolotl.contribs.mit.muon import (
Muon,
MuonOptimizerFactory,
)
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory
@@ -554,8 +556,10 @@ class TestHFCausalTrainerBuilder:
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import MuonOptimizerFactory
from axolotl.contribs.mit.muon.muon import Muon
from axolotl.contribs.mit.muon import (
Muon,
MuonOptimizerFactory,
)
optimizer_cls, optimizer_kwargs = trainer.optimizer_cls_and_kwargs
assert optimizer_cls is MuonOptimizerFactory

View File

@@ -1,168 +0,0 @@
"""Test module for DistMuon optimizer with FSDP2 multi-GPU functionality."""
import os
from pathlib import Path
import torch
import yaml
from accelerate.test_utils import execute_subprocess_async
from tbparse import SummaryReader
from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import most_recent_subdir, require_torch_2_7_0
AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
def verify_training_success(temp_dir):
"""Verify that training completed successfully by checking artifacts and loss.

Asserts that `temp_dir` directly contains at least one `*.bin` or
`*.safetensors` file and at least one `checkpoint-*` entry, then, when
tensorboard scalars are found, asserts the last `train/train_loss` is not NaN.
"""
output_path = Path(temp_dir)
# Model weights may be saved in either format; one match is enough.
model_files = list(output_path.glob("*.bin")) + list(
output_path.glob("*.safetensors")
)
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert len(checkpoint_files) > 0, (
"No checkpoint files found - training may have failed"
)
# `temp_dir` is concatenated as a string here, so it is expected to be a str.
tb_log_path = most_recent_subdir(temp_dir + "/runs")
# NOTE(review): indentation is not visible in this view; the checks below appear
# to be nested, so the loss check is skipped when no log dir, event file or
# `train/train_loss` row exists — confirm against the file.
if tb_log_path:
event_files = sorted(os.listdir(tb_log_path))
if event_files:
# Only the first entry in sorted order is read.
event_file = os.path.join(tb_log_path, event_files[0])
reader = SummaryReader(event_file)
df = reader.scalars
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestDistMuon:
"""Test class for DistMuon optimizer with FSDP2 functionality.

Each test writes a config to `temp_dir`, runs `axolotl train` on it in a
2-process subprocess, and checks the outputs with `verify_training_success`.
"""
# Config without an adapter: optimizer "muon", fsdp_version 2, max_steps 2.
@require_torch_2_7_0
def test_fft_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.02,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Launch `axolotl train` on the written config with 2 processes and a
# port obtained from `get_torch_dist_unique_port()`.
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)
# Same setup as above plus a LoRA adapter (r=8, alpha=16, dropout=0.05,
# lora_target_linear enabled).
@require_torch_2_7_0
def test_lora_sft(self, temp_dir):
cfg = DictDefault(
{
"base_model": "Qwen/Qwen2.5-0.5B",
"sequence_len": 2048,
"val_set_size": 0.01,
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.02,
"optimizer": "muon",
"weight_decay": 0.01,
"lr_scheduler": "cosine",
"flash_attention": True,
"fsdp_version": 2,
"fsdp_config": {
"offload_params": False,
"cpu_ram_efficient_loading": False,
"transformer_layer_cls_to_wrap": "Qwen2DecoderLayer",
"state_dict_type": "FULL_STATE_DICT",
"auto_wrap_policy": "TRANSFORMER_BASED_WRAP",
"reshard_after_forward": True,
},
"use_tensorboard": True,
"bf16": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
# Launch `axolotl train` on the written config with 2 processes and a
# port obtained from `get_torch_dist_unique_port()`.
execute_subprocess_async(
[
"axolotl",
"train",
str(Path(temp_dir) / "config.yaml"),
"--num-processes",
"2",
"--main-process-port",
f"{get_torch_dist_unique_port()}",
]
)
verify_training_success(temp_dir)

View File

@@ -7,7 +7,6 @@ import unittest
from transformers import LlamaTokenizer
from axolotl.utils.data import encode_streaming, md5
from axolotl.utils.trainer import drop_long_seq
from tests.hf_offline_utils import enable_hf_offline
@@ -64,42 +63,6 @@ class TestEncodePretraining(unittest.TestCase):
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)
def test_excess_length_strategy(self):
    """Exercise drop_long_seq with and without raise_on_drop.

    Covers a single 16-token sequence and a batch holding a 15-token and a
    16-token sequence, each checked against a limit of 32 (everything fits)
    and a limit of 15 (the 16-token sequence does not fit).
    """
    roomy_limit, tight_limit = 32, 15

    # -- single sequence --
    single = {"input_ids": list(range(1, 17))}  # tokens 1..16

    # Fits, so raise_on_drop must not trigger.
    drop_long_seq(single, roomy_limit, raise_on_drop=True)
    # Without raise_on_drop a fitting sample yields a truthy result.
    self.assertTrue(drop_long_seq(single, roomy_limit))

    # Too long: raises when raise_on_drop is set ...
    with self.assertRaises(ValueError):
        drop_long_seq(single, tight_limit, raise_on_drop=True)
    # ... and yields a falsy result otherwise.
    self.assertFalse(drop_long_seq(single, tight_limit))

    # -- batch sequence --
    batch = {
        "input_ids": [
            list(range(1, 16)),  # tokens 1..15
            list(range(1, 17)),  # tokens 1..16
        ]
    }

    # Both entries fit, so raise_on_drop must not trigger.
    drop_long_seq(batch, roomy_limit, raise_on_drop=True)

    # The second entry is too long, so raise_on_drop must raise.
    with self.assertRaises(ValueError):
        drop_long_seq(batch, tight_limit, raise_on_drop=True)

    # Per-entry result: the first entry is kept, the second is not.
    self.assertEqual(drop_long_seq(batch, tight_limit), [True, False])
# Run this module's unittest test cases when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@@ -13,9 +13,7 @@ from transformers import PreTrainedTokenizer
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.utils.data.rl import prepare_preference_datasets
from axolotl.utils.data.sft import (
_load_tokenized_prepared_datasets,
)
from axolotl.utils.data.sft import _load_tokenized_prepared_datasets
from axolotl.utils.dict import DictDefault
from tests.constants import (

View File

@@ -363,5 +363,5 @@ class TestOptimizerValidation(BaseValidation):
}
)
with pytest.raises(ValueError, match=r".*only compatible with FSDP2.*"):
with pytest.raises(ValueError, match=r".*is currently incompatible with*"):
validate_config(cfg)

View File

@@ -123,17 +123,6 @@ class TestFSDPValidation:
assert cfg.fsdp_config.transformer_layer_cls_to_wrap == "LlamaDecoderLayer"
assert cfg.fsdp_config.reshard_after_forward is True
def test_muon_fsdp1_rejected(self, min_base_cfg):
    """validate_config must raise ValueError for optimizer="muon" with fsdp_version=1."""
    muon_fsdp1_overrides = DictDefault(
        optimizer="muon",
        fsdp_version=1,
        fsdp_config={"reshard_after_forward": True},
    )
    cfg = min_base_cfg | muon_fsdp1_overrides

    expected_message = "Muon optimizer is only compatible with FSDP2"
    with pytest.raises(ValueError, match=expected_message):
        validate_config(cfg)
@pytest.mark.parametrize(
"rl",
[