Compare commits

...

6 Commits

Author SHA1 Message Date
Wing Lian
5e50d1e8f0 batch flattening with xformers too 2025-05-08 18:23:25 -04:00
Wing Lian
7fb01f0461 also support xformers w/o packing 2025-05-08 15:22:48 -04:00
Wing Lian
0f3587174d swap tinymodels that have safetensors for some ci tests (#2641) 2025-05-07 15:06:07 -04:00
xzuyn
25e6c5f9bd Add CAME Optimizer (#2385) 2025-05-07 10:31:46 -04:00
NanoCode012
32f51bca35 fix(doc): clarify instruction to delinearize llama4 similar to cli doc (#2644) [skip ci] 2025-05-07 10:29:47 -04:00
NanoCode012
9daa04da90 Fix: improve error message on failed dataset load (#2637) [skip ci]
* fix(log): clarify error on dataset loading failed

* fix: add path for easy tracking of broken config

* fix: improve error message based on pr feedback
2025-05-07 10:29:05 -04:00
25 changed files with 220 additions and 31 deletions

View File

@@ -18,9 +18,96 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2

View File

@@ -612,6 +612,7 @@ lr_div_factor: # Learning rate div factor
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
# - came_pytorch
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:

View File

@@ -34,3 +34,5 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.

View File

@@ -11,6 +11,7 @@ liger-kernel==0.5.9
packaging==23.2
huggingface_hub==0.31.0
peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1

View File

@@ -142,6 +142,7 @@ extras_require = {
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]",

View File

@@ -708,6 +708,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:

View File

@@ -2,6 +2,7 @@
import importlib
import inspect
import logging
import os
import signal
import sys
@@ -12,7 +13,6 @@ from typing import Any, Dict
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
@@ -42,7 +42,7 @@ try:
except ImportError:
BetterTransformer = None
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def setup_model_and_tokenizer(
@@ -63,7 +63,6 @@ def setup_model_and_tokenizer(
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)

View File

@@ -281,6 +281,10 @@ def load_dataset_w_config(
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")
raise ValueError(
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({config_dataset.path}). Try double-check your path / name / data_files. "
"This is not caused by the dataset type."
)
return ds

View File

@@ -1,15 +1,36 @@
"""custom checkpointing utils"""
import importlib
from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
if transformers_version > version.parse("4.51.3"):
from transformers.modeling_layers import GradientCheckpointingLayer
def uses_gc_layers(decoder_layer):
return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)
else:
def uses_gc_layers(_):
return False
def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__

View File

@@ -556,7 +556,7 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.xformers_attention and self.cfg.sample_packing:
if self.cfg.xformers_attention:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
@@ -771,13 +771,6 @@ class ModelLoader:
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
elif self.cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif self.cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,

View File

@@ -475,8 +475,14 @@ class AxolotlInputConfig(
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto:
raise ValueError("batch_flattening requires flash attention")
if (
not data.get("flash_attention")
and not data.get("xformers_attention")
and not batch_flattening_auto
):
raise ValueError(
"batch_flattening requires flash attention or xformers"
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:

View File

@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name

View File

@@ -41,6 +41,7 @@ class WandbConfig(BaseModel):
use_wandb: bool | None = None
wandb_name: str | None = None
wandb_run_id: str | None = None
wandb_run_group: str | None = None
wandb_mode: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None

View File

@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
lr_groups: list[LrGroup] | None = None
adam_epsilon: float | None = None
adam_epsilon2: float | None = None
adam_beta1: float | None = None
adam_beta2: float | None = None
adam_beta3: float | None = None
max_grad_norm: float | None = None
num_epochs: float = Field(default=1.0)

View File

@@ -479,7 +479,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -29,12 +29,12 @@ from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "openaccess-ai-collective/tiny-mistral",
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "Qwen/Qwen2-7B",
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
@@ -44,7 +44,7 @@ MODEL_CONFIGS = [
"dtype": torch.float32,
},
{
"name": "mhenrichsen/gemma-2b",
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
@@ -156,7 +156,9 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
"trl-internal-testing/tiny-Gemma2ForCausalLM",
torch_dtype=torch.float16,
device_map="cuda:0",
)
peft_config = get_peft_config(
{

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalconPatched(unittest.TestCase):
Test case for Falcon models
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -71,6 +74,7 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -28,7 +28,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -76,7 +76,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,

View File

@@ -56,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
def test_mistral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -15,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -26,6 +26,7 @@ class TestResumeLlama:
Test case for resuming training of llama models
"""
@require_torch_2_6_0
def test_resume_lora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -62,6 +63,7 @@ class TestResumeLlama:
"save_total_limit": 5,
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
}
)
if is_torch_bf16_gpu_available():

View File

@@ -19,14 +19,11 @@ class TestE2eEvaluate:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalcon(unittest.TestCase):
Test case for falcon
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -74,6 +77,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
# pylint: disable=duplicate-code
@@ -129,6 +133,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -30,7 +30,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -77,7 +77,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,

View File

@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_came_pytorch(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "came_pytorch",
"adam_beta3": 0.9999,
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -414,7 +414,6 @@ class TestDatasetPreparation:
snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
)
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)