diff --git a/README.md b/README.md
index 0a4e2e5a1..fd293bd04 100644
--- a/README.md
+++ b/README.md
@@ -46,6 +46,7 @@ Features:
- [Multipack](./docs/multipack.qmd)
- [RLHF & DPO](./docs/rlhf.qmd)
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)
+ - [Unsloth](./docs/unsloth.qmd)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
diff --git a/_quarto.yml b/_quarto.yml
index 009fa8056..6b2eed971 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -36,6 +36,7 @@ website:
- docs/nccl.qmd
- docs/mac.qmd
- docs/multi-node.qmd
+ - docs/unsloth.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd
new file mode 100644
index 000000000..390609fd3
--- /dev/null
+++ b/docs/unsloth.qmd
@@ -0,0 +1,49 @@
+---
+title: "Unsloth"
+description: "Hyper-optimized QLoRA finetuning for single GPUs"
+---
+
+### Overview
+
+Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over
+standard industry baselines.
+
+
+### Installation
+
+The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
+to date libraries.
+
+```bash
+pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
+pip install --no-deps --force-reinstall xformers==0.0.26.post1
+```
+
+### Using unsloth w Axolotl
+
+Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
+
+Our unsloth integration is currently limited to the following model architectures:
+ - llama
+
+These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
+```yaml
+unsloth_lora_mlp: true
+unsloth_lora_qkv: true
+unsloth_lora_o: true
+```
+
+These options are composable and can be used with multi-gpu finetuning
+```
+unsloth_cross_entropy_loss: true
+unsloth_rms_norm: true
+unsloth_rope: true
+```
+
+### Limitations
+
+- Single GPU only; e.g. no multi-gpu support
+- No deepspeed or FSDP support (requires multi-gpu)
+- LoRA + QLoRA support only. No full fine tunes or fp8 support.
+- Limited model architecture support. Llama, Phi, Gemma, Mistral only
+- No MoE support.
diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py
index 6af3046e1..b1f0bddc0 100644
--- a/src/axolotl/monkeypatch/unsloth_.py
+++ b/src/axolotl/monkeypatch/unsloth_.py
@@ -1,18 +1,20 @@
"""module for patching with unsloth optimizations"""
import inspect
-import logging
import re
import types
from typing import Tuple
+import torch
+from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
+from torch import nn
from transformers.models.llama.modeling_llama import (
LlamaFlashAttention2,
LlamaForCausalLM,
)
-LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
+LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_CEL_CODE = """ if labels is not None:
# Shift so that tokens < n predict n
@@ -137,7 +139,7 @@ def integrate_cross_entropy_loss_patch():
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
- print("patching unsloth fast_cross_entropy_loss")
+ LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
@@ -179,12 +181,30 @@ def patch_self_attn_lora():
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
- print("patching unsloth attn lora")
+ LOG.info("patching unsloth attn lora", main_process_only=True)
LlamaFlashAttention2.forward = (
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
+def integrate_rope_embeddings():
+ import transformers.models.llama.modeling_llama
+ from unsloth.kernels.rope_embedding import fast_rope_embedding
+
+ def apply_rotary_pos_emb( # pylint: disable=unused-argument
+ q, # pylint: disable=invalid-name
+ k, # pylint: disable=invalid-name
+ cos,
+ sin,
+ position_ids=None,
+ unsqueeze_dim=1,
+ ):
+ return fast_rope_embedding(q, k, cos, sin)
+
+ LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
+ transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
+
+
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
from unsloth.kernels import apply_lora_mlp_swiglu
@@ -217,7 +237,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
else:
- logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
+ LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
@@ -243,9 +263,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_qkv = apply_lora_qkv
else:
layer.self_attn.apply_qkv = original_apply_qkv
- logging.warning(
- "unable to apply unsloth lora qkv patch to layer %d", idx
- )
+ LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
if cfg.unsloth_lora_o:
layer_modules = [
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
@@ -264,6 +282,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
layer.self_attn.apply_o = apply_lora_o
else:
layer.self_attn.apply_o = original_apply_o
- logging.warning(
+ LOG.warning(
"unable to apply unsloth lora o_proj patch to layer %d", idx
)
+
+
+def patch_unsloth_layernorm():
+ try:
+ import transformers.models.llama.modeling_llama
+ from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
+
+ class LlamaRMSNorm(nn.Module):
+ """LlamaRMSNorm"""
+
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ LlamaRMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ return Fast_RMS_Layernorm.apply(
+ hidden_states, self.weight, self.variance_epsilon, False
+ )
+
+ LOG.info("patching with unsloth.kernels.rms_layernorm")
+ transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
+ except ImportError:
+ LOG.warning("missing unsloth library")
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index f0c6fa0ea..f3acbbc68 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -7,6 +7,7 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
+from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
@@ -596,6 +597,8 @@ class AxolotlInputConfig(
unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None
unsloth_lora_o: Optional[bool] = None
+ unsloth_rms_norm: Optional[bool] = None
+ unsloth_rope: Optional[bool] = None
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
fsdp: Optional[List[str]] = None
@@ -1164,6 +1167,21 @@ class AxolotlInputConfig(
)
return data
+ @model_validator(mode="before")
+ @classmethod
+ def check_unsloth_xformers_version(cls, data):
+ if (
+ data.get("unsloth_lora_mlp")
+ or data.get("unsloth_lora_qkv")
+ or data.get("unsloth_lora_o")
+ ):
+ xformers_version = version("xformers")
+ if xformers_version == "0.0.27":
+ raise ValueError(
+ "xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
+ )
+ return data
+
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 51ce5a29b..6185f0102 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -1,7 +1,7 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
-
+import gc
import logging
import math
import os
@@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
"Please make sure to point to a GPTQ model."
)
- if not cfg.gptq and quant_config_exists:
+ if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
@@ -358,6 +358,10 @@ def load_model(
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
+ elif cfg.unsloth_rms_norm:
+ from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
+
+ patch_unsloth_layernorm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
@@ -884,6 +888,15 @@ def load_model(
integrate_lora_patch(model, cfg)
+ if cfg.unsloth_rope:
+ from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
+
+ integrate_rope_embeddings()
+
+ for _ in range(3):
+ gc.collect()
+ torch.cuda.empty_cache()
+
# TODO resume_from_checkpoint handling
return model, lora_config