Compare commits

...

14 Commits

Author SHA1 Message Date
Wing Lian
1447beb132 make sure to validate the config before normalizing so defaults get set (#2554)
* make sure to validate the config before normalizing so defaults get set

* validation not needed for particular test

* remove duplicate validations

* set qlora correctly
2025-04-24 13:01:43 -04:00
Dan Saunders
66f41ec6f1 disable codecov pr annotations (#2556) 2025-04-24 08:51:51 -04:00
NanoCode012
85053f4bd4 Fix(doc): add delinearize instruction (#2545)
* fix: mention to install pytorch before axolotl

* feat(doc): include instruction to delinearize

* fix: update instruction for delinearize with adapter
2025-04-24 01:03:43 -04:00
Wing Lian
a4d5112ae1 builds for torch 2.7.0 (#2552)
* builds for torch==2.7.0

* use xformers==0.0.29.post3

* no vllm support with torch 2.7

* update default, fix conditional

* no xformers for 270

* no vllm on 2.7.0 for multigpu test too

* remove deprecated verbose arg from scheduler

* 2.7.0 tests on cpu
2025-04-24 00:39:31 -04:00
Wing Lian
0d691cc2a7 add base docker image with pytorch 2.7.0 and variant for cuda 12.8 (#2551)
* add base docker image with pytorch 2.7.0 and variant for cuda 12.8

* my bash is terrible
2025-04-23 14:59:03 -04:00
Dan Saunders
c4053481ff Codecov fixes / improvements (#2549)
* adding codecov reporting

* random change

* codecov fixes

* adding missing dependency

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
2025-04-23 10:33:30 -04:00
NanoCode012
a6d28d19b1 feat: add glm and glm4 multipack and cce (#2546)
* feat: add glm and glm4 multipack

* feat: add glm4 example

* feat: add cce for glm
2025-04-23 10:27:51 -04:00
Wing Lian
32e335dd51 fix missing host/port for vllm (#2543)
* fix missing host/port for vllm

* set tensor parallel size so it doesn't always default to cli override
2025-04-22 10:16:48 -04:00
Wing Lian
7651550850 make sure to download fixtures for kd test (#2541)
* make sure to download fixtures for kd test

* use same alpaca dataset
2025-04-21 10:31:50 -04:00
Wing Lian
341e95aac9 prevent rate limiting to hf when using dispatch batches (#2536) [skip ci] 2025-04-21 10:31:35 -04:00
Catgat
b882dfb63f Fixed Rex Scheduler Warm Up (#2535) [skip ci]
* Fixed Rex Scheduler Warm Up

* chore: lint

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-04-21 10:30:55 -04:00
Wing Lian
b640db1dbc don't run multigpu tests twice, run SP in separate test (#2542)
* don't run multigpu tests twice, run SP in separate test

* fix multiline
2025-04-21 10:24:13 -04:00
Chiwan Park
4ce469d32e fix: upgrade liger to 0.5.8 and use native Gemma3 patches (#2527)
* fix: upgrade liger to 0.5.8 and use native Gemma3 patches

* fix: make lint happy

* doc: update Liger Kernel FLCE support for Gemma 3
2025-04-18 09:57:40 -07:00
Wing Lian
60a8f0958d zero val fix for beta (#2538) 2025-04-17 17:27:19 -07:00
57 changed files with 362 additions and 101 deletions

View File

@@ -46,6 +46,18 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.7.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""

View File

@@ -31,6 +31,11 @@ jobs:
pytorch: 2.6.0
axolotl_extras: vllm
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras: vllm
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -93,6 +98,11 @@ jobs:
pytorch: 2.6.0
axolotl_extras:
is_latest: true
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -138,7 +148,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:

View File

@@ -45,6 +45,13 @@ jobs:
axolotl_extras: vllm
num_gpus: 2
nightly_build: "true"
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -67,6 +74,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.multigpu

View File

@@ -147,6 +147,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "NIGHTLY_BUILD=${{ matrix.nightly_build }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

View File

@@ -49,7 +49,7 @@ jobs:
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0", "2.7.0"]
timeout-minutes: 20
steps:
@@ -109,6 +109,7 @@ jobs:
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
@@ -241,6 +242,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests
@@ -268,6 +270,12 @@ jobs:
pytorch: 2.5.1
num_gpus: 1
axolotl_extras: vllm
- cuda: 126
cuda_version: 12.6.3
python_version: "3.11"
pytorch: 2.7.0
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -288,6 +296,7 @@ jobs:
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
echo "CODECOV_TOKEN=${{ secrets.CODECOV_TOKEN }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.e2e_tests

View File

@@ -9,8 +9,7 @@ pytest -v --durations=10 -n8 \
--ignore=tests/patched/ \
--ignore=tests/cli \
/workspace/axolotl/tests/ \
--cov=axolotl \
--cov-report=xml:coverage.xml
--cov=axolotl
# Run lora kernels tests with coverage append
pytest -v --durations=10 \
@@ -51,11 +50,6 @@ pytest -v --durations=10 \
/workspace/axolotl/tests/e2e/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:coverage.xml
--cov-report=xml:e2e-coverage.xml
# Upload coverage to Codecov
if [ -f e2e-coverage.xml ]; then
codecov -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi
codecov upload-process -t $CODECOV_TOKEN -f e2e-coverage.xml -F e2e,pytorch-${PYTORCH_VERSION}

View File

@@ -28,6 +28,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -29,6 +29,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}

View File

@@ -1,25 +1,23 @@
#!/bin/bash
set -e
# only run one test at a time so as not to OOM the GPU
pytest -v --durations=10 -n2 /workspace/axolotl/tests/e2e/multigpu/ --ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/
# Only run two tests at a time to avoid OOM on GPU (with coverage collection)
pytest -v -n2 \
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/
--ignore=/workspace/axolotl/tests/e2e/multigpu/solo/ \
--ignore=/workspace/axolotl/tests/e2e/multigpu/patched/ \
/workspace/axolotl/tests/e2e/multigpu/ \
--cov=axolotl \
--cov-report=xml:multigpu-coverage.xml
--cov=axolotl
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/solo/ \
# Run solo tests with coverage append
pytest -v --durations=10 -n1 \
/workspace/axolotl/tests/e2e/multigpu/solo/ \
--cov=axolotl \
--cov-append
pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/multigpu/patched/ \
--cov=axolotl \
--cov-append \
--cov-report=xml:multigpu-coverage.xml
# Upload coverage to Codecov
if [ -f multigpu-coverage.xml ]; then
codecov -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}
else
echo "Coverage file not found. Coverage report may have failed."
fi
codecov upload-process -t $CODECOV_TOKEN -f multigpu-coverage.xml -F multigpu,docker-tests,pytorch-${PYTORCH_VERSION}

View File

@@ -49,3 +49,6 @@ comment:
require_changes: no
require_base: no
require_head: yes
github_checks:
annotations: false

View File

@@ -37,3 +37,7 @@ RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10
RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
pip3 install flash-attn==2.7.4.post1; \
fi

View File

@@ -199,6 +199,17 @@ output_dir: # Directory to save evaluation results
See [LM Eval Harness](https://github.com/EleutherAI/lm-evaluation-harness) for more details.
### delinearize-llama4
Delinearizes a Llama 4 linearized model into a regular HuggingFace Llama 4 model. This only works with the non-quantized linearized model.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
This is necessary for use with other frameworks. If you have an adapter, merge it with the non-quantized linearized model before delinearizing.
## Legacy CLI Usage
While the new Click-based CLI is preferred, Axolotl still supports the legacy module-based CLI:

View File

@@ -19,6 +19,12 @@ This guide covers all the ways you can install and set up Axolotl for your envir
## Installation Methods {#sec-installation-methods}
::: {.callout-important}
Please make sure to have PyTorch installed before installing Axolotl in your local environment.
Follow the instructions at: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
:::
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}

View File

@@ -0,0 +1,62 @@
base_model: THUDM/GLM-4-32B-0414
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_4bit: true
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0
output_dir: ./outputs/qlora-out
adapter: qlora
lora_model_dir:
sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
bf16: auto
tf32: false
gradient_checkpointing: true
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3
warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:

View File

@@ -26,3 +26,11 @@ Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU @ 4k context length @
### Llama 4 Maverick 17Bx128Experts (400B)
Coming Soon
## Delinearized Llama 4 Models
We provide a script to delinearize Llama 4 linearized models into regular HuggingFace Llama 4 models.
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```

View File

@@ -1,4 +1,5 @@
codecov
codecov-cli
pytest
pytest-cov
pytest-retry

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.6
liger-kernel==0.5.8
# END section
packaging==23.2
@@ -19,6 +19,7 @@ datasets==3.5.0
deepspeed>=0.15.4
trl==0.16.1
hf_xet==1.0.0
hqq==0.2.5
optimum==1.16.2
hf_transfer

View File

@@ -51,7 +51,7 @@ def parse_requirements(extras_require_map):
try:
torch_version = version("torch")
except PackageNotFoundError:
torch_version = "2.5.1"
torch_version = "2.6.0" # default to torch 2.6
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
@@ -64,9 +64,15 @@ def parse_requirements(extras_require_map):
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 6):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.3"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
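The gate above keys off only the major and minor components parsed by the regex; a quick worked parse for clarity (the version string is illustrative; wheel builds often carry a local suffix like `+cu128`):

```python
import re

m = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", "2.7.0+cu128")
major, minor = int(m.group(1)), int(m.group(2))
print((major, minor) >= (2, 7))  # True; the local version suffix is ignored
```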

View File

@@ -39,16 +39,16 @@ class TrainerCliArgs:
class VllmServeCliArgs:
"""Dataclass with CLI arguments for `axolotl vllm-serve` command."""
tensor_parallel_size: int = field(
default=1,
tensor_parallel_size: Optional[int] = field(
default=None,
metadata={"help": "Number of tensor parallel workers to use."},
)
host: str = field(
default="0.0.0.0", # nosec B104
host: Optional[str] = field(
default=None, # nosec B104
metadata={"help": "Host address to run the server on."},
)
port: int = field(
default=8000,
port: Optional[int] = field(
default=None,
metadata={"help": "Port to run the server on."},
)
gpu_memory_utilization: Optional[float] = field(
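Moving the CLI defaults to `None` is what lets the new `vllm.host`/`vllm.port` config fields (added further below) take effect: a concrete CLI default would always override the YAML value. A minimal sketch of the implied merge rule, with a hypothetical helper name:

```python
def resolve(cli_value, cfg_value):
    # hypothetical helper: a CLI flag overrides config only when explicitly set
    return cli_value if cli_value is not None else cfg_value

print(resolve(None, "0.0.0.0"))  # "0.0.0.0" -- the YAML vllm.host default applies
print(resolve(9000, 8000))       # 9000 -- an explicit CLI value still wins
```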

View File

@@ -1040,9 +1040,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dataset_processes:
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if (self.cfg.trl and self.cfg.trl.beta) or self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.trl.beta or self.cfg.rl_beta
if self.cfg.orpo_alpha:
if self.cfg.trl and self.cfg.trl.beta is not None:
training_args_kwargs["beta"] = self.cfg.trl.beta
elif self.cfg.rl_beta is not None:
training_args_kwargs["beta"] = self.cfg.rl_beta
elif self.cfg.orpo_alpha is not None:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha

View File

@@ -40,8 +40,8 @@ class GRPOStrategy:
if trl.use_vllm:
grpo_args_kwargs["use_vllm"] = trl.use_vllm
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port
grpo_args_kwargs["vllm_server_host"] = trl.vllm_server_host or trl.vllm.host
grpo_args_kwargs["vllm_server_port"] = trl.vllm_server_port or trl.vllm.port
if trl.vllm_server_timeout:
grpo_args_kwargs["vllm_server_timeout"] = trl.vllm_server_timeout
if trl.vllm_guided_decoding_regex:

View File

@@ -47,6 +47,8 @@ cut_cross_entropy: true
- qwen2
- cohere
- cohere2
- glm
- glm4
## Citation

View File

@@ -0,0 +1,57 @@
"""GLM 4 patch. GLM family inherits from Llama."""
from types import MethodType
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
)
def patch_glm(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm import modeling_glm
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm.GlmForCausalLM
), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm.GlmForCausalLM.forward = cce_forward
return None
def patch_glm4(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
# Set the _PATCH_OPTS in the llama patch file
import cut_cross_entropy.transformers.llama as llama_patch
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
from cut_cross_entropy.transformers.llama import cce_forward
from transformers.models.glm4 import modeling_glm4
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_glm4.Glm4ForCausalLM
), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_glm4.Glm4ForCausalLM.forward = cce_forward
return None

View File

@@ -20,6 +20,10 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import (
patch_gemma3,
patch_gemma3_text,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
patch_glm,
patch_glm4,
)
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
patch_llama4,
patch_llama4_text,
@@ -45,6 +49,8 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
"qwen2": patch_qwen2,
"cohere": patch_cohere,
"cohere2": patch_cohere2,
"glm": patch_glm,
"glm4": patch_glm4,
}
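With these entries in place, routing is a plain dictionary lookup; a hedged illustration, assuming the mapping is importable from the monkeypatch package shown above:

```python
# assumption: the mapping is exported from this monkeypatch package
from axolotl.integrations.cut_cross_entropy.monkeypatch import (
    CUT_CROSS_ENTROPY_MODEL_MAPPING,
)

patch_fn = CUT_CROSS_ENTROPY_MODEL_MAPPING["glm4"]  # -> patch_glm4
```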

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- gemma3
- granite
- jamba
- llama

View File

@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial
from axolotl.integrations.base import BasePlugin
@@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin):
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,
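The removed Gemma 3 branch is superseded by liger-kernel 0.5.8's native patches, so Gemma 3 can flow through the same generic registry as other models. A hedged sketch, assuming 0.5.8 registers a `gemma3` entry in the mapping imported above:

```python
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN

# assumption: liger-kernel 0.5.8 registers Gemma 3 natively in this mapping
apply_fn = MODEL_TYPE_TO_APPLY_LIGER_FN["gemma3"]
```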

View File

@@ -31,6 +31,8 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"starcoder2",
"deepseek_v2",
"deepseek_v3",
"glm",
"glm4",
]

View File

@@ -272,7 +272,7 @@ class ReLoRAScheduler(LRScheduler):
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
super().__init__(optimizer, inner_schedule.last_epoch)
def get_lr(self) -> float:
self.inner_schedule.last_epoch = self.last_epoch

View File

@@ -3,6 +3,7 @@
import functools
import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
@@ -117,9 +118,27 @@ def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
# when letting accelerator dispatch batches from the main process, we don't need to load the dataset from
# other ranks, we just need to present a fake dataset
if (
cfg.accelerator_config
and cfg.accelerator_config.dispatch_batches
and not is_local_main_process()
):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as f:
f.write("text\n")
f.write("lorem ipsum dolor sit amet\n")
# rewind the file pointer to the beginning so we can read it again
f.seek(0)
iter_ds = load_dataset(
"csv", data_files=f.name, split="train", streaming=True
)
else:
if is_local_main_process():
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
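The stand-in dataset only engages for the specific combination checked above; an illustrative config fragment (dataset path hypothetical) showing the knob that triggers it:

```python
from axolotl.utils.dict import DictDefault

# with dispatch_batches enabled, only the local main process streams the real
# pretraining dataset from the HF Hub; other ranks read the tiny stand-in CSV,
# which avoids Hub rate limiting
cfg = DictDefault(
    {
        "pretraining_dataset": [{"path": "c4", "type": "pretrain"}],
        "accelerator_config": {"dispatch_batches": True},
    }
)
```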

View File

@@ -40,7 +40,7 @@ class RexLR(LRScheduler):
self.max_lr = max_lr
self.total_steps = total_steps
self.num_warmup_steps = num_warmup_steps
self.last_step = last_step - 1
self.last_step = max(last_step - 1, 0)
# Ensure each parameter group has an "initial_lr" key to avoid issues when resuming.
for group in optimizer.param_groups:

View File

@@ -660,6 +660,7 @@ class AxolotlInputConfig(
data.get("val_set_size") == 0
and (data.get("eval_steps") or data.get("eval_strategy"))
and not data.get("test_datasets")
and data.get("eval_strategy") != "no"
):
raise ValueError(
"eval_steps and eval_strategy are not supported with val_set_size == 0"

View File

@@ -36,3 +36,11 @@ class VllmConfig(BaseModel):
default=None,
json_schema_extra={"description": "Enable prefix caching for VLLM"},
)
host: str | None = Field(
default="0.0.0.0", # nosec B104
json_schema_extra={"description": "Host for the vLLM server to start on"},
)
port: int | None = Field(
default=8000,
json_schema_extra={"description": "Port of the vLLM server to start on"},
)

View File

@@ -193,6 +193,14 @@ def download_tiny_shakespeare_dataset():
snapshot_download_w_retry("winglian/tiny-shakespeare", repo_type="dataset")
@pytest.fixture(scope="session", autouse=True)
def download_evolkit_kd_sample_dataset():
# download the dataset
snapshot_download_w_retry(
"axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", repo_type="dataset"
)
@pytest.fixture(scope="session", autouse=True)
def download_deepseek_model_fixture():
snapshot_download_w_retry("axolotl-ai-co/DeepSeek-V3-11M", repo_type="model")
@@ -208,6 +216,16 @@ def download_huggyllama_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama33_70b_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
"axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
repo_type="model",
allow_patterns=["*token*", "config.json"],
)
@pytest.fixture(scope="session", autouse=True)
def download_llama_1b_model_fixture():
# download the tokenizer only
@@ -315,6 +333,14 @@ def download_llama2_model_fixture():
)
@pytest.fixture(scope="session", autouse=True)
def download_llama32_1b_model_fixture():
snapshot_download_w_retry(
"osllmai-community/Llama-3.2-1B",
repo_type="model",
)
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(

View File

@@ -8,7 +8,7 @@ from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
@@ -56,6 +56,7 @@ class TestCutCrossEntropyIntegration:
# pylint: disable=redefined-outer-name
def test_llama_w_cce(self, min_cfg, temp_dir):
cfg = DictDefault(min_cfg)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -101,6 +102,7 @@ class TestCutCrossEntropyIntegration:
"bf16": "auto",
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -129,6 +131,7 @@ class TestCutCrossEntropyIntegration:
attention_type: True,
}
)
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
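This call order is the point of the "make sure to validate the config before normalizing" commit: `validate_config` runs the schema first so defaults (e.g. the qlora-related flags) are populated before `normalize_config` derives settings from them. A minimal sketch of the sequence the tests now follow, with an illustrative (and deliberately incomplete) config:

```python
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"base_model": "HuggingFaceTB/SmolLM2-135M"})  # illustrative
cfg = validate_config(cfg)  # validation first, so schema defaults get set
prepare_plugins(cfg)
normalize_config(cfg)       # normalization can now rely on those defaults
```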

View File

@@ -5,7 +5,7 @@ Simple end-to-end test for Liger integration
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_model_output_exists, require_torch_2_4_1
@@ -54,6 +54,7 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
}
)
# pylint: disable=duplicate-code
cfg = validate_config(cfg)
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -10,7 +10,7 @@ from transformers.testing_utils import get_torch_dist_unique_port
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ...utils import check_tensorboard
os.environ["WANDB_DISABLED"] = "true"

View File

@@ -0,0 +1,2 @@
# Tests under this directory should get run "solo" on their own as they
# seem to cause issues when run in the same batch as other tests.

View File

@@ -49,8 +49,9 @@ class TestPackedFlex:
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -60,6 +60,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ class Test4dMultipackLlama(unittest.TestCase):
"fp16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -103,6 +104,7 @@ class TestFalconPatched(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -67,6 +67,7 @@ class TestFusedLlama(unittest.TestCase):
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -65,6 +65,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -105,6 +106,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_availab
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -70,6 +70,7 @@ class TestLoraLlama(unittest.TestCase):
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -120,6 +121,7 @@ class TestLoraLlama(unittest.TestCase):
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -104,6 +105,7 @@ class TestMistral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -60,6 +60,7 @@ class TestMixtral(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -6,7 +6,7 @@ import unittest
import transformers
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -47,6 +47,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)
@@ -79,6 +80,7 @@ class TestModelPatches(unittest.TestCase):
"eval_steps": 10,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=False)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
@@ -63,6 +63,7 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -82,7 +83,7 @@ class TestPhiMultipack(unittest.TestCase):
"sample_packing": True,
"flash_attention": True,
"pad_to_sequence_len": True,
"load_in_8bit": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -114,6 +115,7 @@ class TestPhiMultipack(unittest.TestCase):
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -12,7 +12,7 @@ from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
@@ -46,8 +46,9 @@ class TestResumeLlama:
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 2,
@@ -67,6 +68,7 @@ class TestResumeLlama:
cfg.bf16 = True
else:
cfg.fp16 = True
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -10,7 +10,7 @@ import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
@@ -72,6 +72,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -122,6 +123,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -177,6 +179,7 @@ class TestUnslothQLoRA:
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -41,8 +41,9 @@ class TestPackedFlex(unittest.TestCase):
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

View File

@@ -102,6 +102,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
"use_tensorboard": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -109,6 +109,7 @@ class TestLlamaVision(unittest.TestCase):
"bf16": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -40,8 +40,9 @@ class TestPackedLlama(unittest.TestCase):
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"path": "tatsu-lab/alpaca",
"type": "alpaca",
"split": "train[:10%]",
},
],
"num_epochs": 1,

View File

@@ -79,7 +79,7 @@ class TestPhi(unittest.TestCase):
"tokenizer_type": "AutoTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"load_in_8bit": False,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 64,
"lora_alpha": 32,
@@ -111,6 +111,7 @@ class TestPhi(unittest.TestCase):
"bf16": "auto",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -9,7 +9,7 @@ import unittest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
@@ -57,6 +57,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
"seed": 42,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -11,7 +11,7 @@ from unittest.mock import patch
import pytest
from datasets import Dataset
from axolotl.utils.config import normalize_config
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.data.utils import deduplicate_and_log_datasets
@@ -319,6 +319,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
"num_epochs": 1,
}
)
self.cfg_1 = validate_config(self.cfg_1)
normalize_config(self.cfg_1)
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")