Unsloth rope (#1767)
* Add unsloth rope embeddings support * support for models weights in 4bit and do some memory gc * use accelerate logger * add unsloth llama rms norm optims * update docs for unsloth * more docs info
This commit is contained in:
@@ -46,6 +46,7 @@ Features:
|
|||||||
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [Multipack](./docs/multipack.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [RLHF & DPO](./docs/rlhf.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
- [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
|
- [Unsloth](./docs/unsloth.qmd)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||||
- [Debugging Axolotl](#debugging-axolotl)
|
- [Debugging Axolotl](#debugging-axolotl)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ website:
|
|||||||
- docs/nccl.qmd
|
- docs/nccl.qmd
|
||||||
- docs/mac.qmd
|
- docs/mac.qmd
|
||||||
- docs/multi-node.qmd
|
- docs/multi-node.qmd
|
||||||
|
- docs/unsloth.qmd
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
- section: "Reference"
|
- section: "Reference"
|
||||||
|
|||||||
49
docs/unsloth.qmd
Normal file
49
docs/unsloth.qmd
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
---
|
||||||
|
title: "Unsloth"
|
||||||
|
description: "Hyper-optimized QLoRA finetuning for single GPUs"
|
||||||
|
---
|
||||||
|
|
||||||
|
### Overview
|
||||||
|
|
||||||
|
Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and reduce VRAM usage compared to
|
||||||
|
standard industry baselines.
|
||||||
|
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
|
||||||
|
to date libraries.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
|
||||||
|
pip install --no-deps --force-reinstall xformers==0.0.26.post1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using unsloth with Axolotl
|
||||||
|
|
||||||
|
Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains.
|
||||||
|
|
||||||
|
Our unsloth integration is currently limited to the following model architectures:
|
||||||
|
- llama
|
||||||
|
|
||||||
|
These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning
|
||||||
|
```yaml
|
||||||
|
unsloth_lora_mlp: true
|
||||||
|
unsloth_lora_qkv: true
|
||||||
|
unsloth_lora_o: true
|
||||||
|
```
|
||||||
|
|
||||||
|
These options are composable and can be used with multi-gpu finetuning
|
||||||
|
```yaml
|
||||||
|
unsloth_cross_entropy_loss: true
|
||||||
|
unsloth_rms_norm: true
|
||||||
|
unsloth_rope: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Limitations
|
||||||
|
|
||||||
|
- Single GPU only; i.e. no multi-gpu support
|
||||||
|
- No deepspeed or FSDP support (requires multi-gpu)
|
||||||
|
- LoRA + QLoRA support only. No full fine tunes or fp8 support.
|
||||||
|
- Limited model architecture support. Llama, Phi, Gemma, Mistral only
|
||||||
|
- No MoE support.
|
||||||
@@ -1,18 +1,20 @@
|
|||||||
"""module for patching with unsloth optimizations"""
|
"""module for patching with unsloth optimizations"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import re
|
import re
|
||||||
import types
|
import types
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
|
from torch import nn
|
||||||
from transformers.models.llama.modeling_llama import (
|
from transformers.models.llama.modeling_llama import (
|
||||||
LlamaFlashAttention2,
|
LlamaFlashAttention2,
|
||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.unsloth")
|
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||||
|
|
||||||
ORIGINAL_CEL_CODE = """ if labels is not None:
|
ORIGINAL_CEL_CODE = """ if labels is not None:
|
||||||
# Shift so that tokens < n predict n
|
# Shift so that tokens < n predict n
|
||||||
@@ -137,7 +139,7 @@ def integrate_cross_entropy_loss_patch():
|
|||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
print("patching unsloth fast_cross_entropy_loss")
|
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
||||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
|
||||||
|
|
||||||
@@ -179,12 +181,30 @@ def patch_self_attn_lora():
|
|||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
print("patching unsloth attn lora")
|
LOG.info("patching unsloth attn lora", main_process_only=True)
|
||||||
LlamaFlashAttention2.forward = (
|
LlamaFlashAttention2.forward = (
|
||||||
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def integrate_rope_embeddings():
    """Monkeypatch the Llama rotary-embedding helper with unsloth's kernel.

    Rebinds ``apply_rotary_pos_emb`` in
    ``transformers.models.llama.modeling_llama`` to a thin wrapper around
    ``unsloth.kernels.rope_embedding.fast_rope_embedding``. Both imports are
    done lazily so that unsloth is only required when this patch is requested.
    """
    from transformers.models.llama import modeling_llama
    from unsloth.kernels.rope_embedding import fast_rope_embedding

    def apply_rotary_pos_emb(  # pylint: disable=unused-argument
        q,  # pylint: disable=invalid-name
        k,  # pylint: disable=invalid-name
        cos,
        sin,
        position_ids=None,
        unsqueeze_dim=1,
    ):
        # `position_ids` and `unsqueeze_dim` are accepted only so the wrapper
        # keeps the same call signature as the function it replaces; they are
        # not forwarded to the unsloth kernel.
        return fast_rope_embedding(q, k, cos, sin)

    LOG.info("patching unsloth RoPE embeddings", main_process_only=True)
    modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
||||||
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
if peft_model.base_model.config.model_type in ["llama", "mistral"]:
|
||||||
from unsloth.kernels import apply_lora_mlp_swiglu
|
from unsloth.kernels import apply_lora_mlp_swiglu
|
||||||
@@ -217,7 +237,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
|
|||||||
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
if is_mlp_lora and mlp_no_bias and mlp_not_dora:
|
||||||
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp)
|
||||||
else:
|
else:
|
||||||
logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx)
|
||||||
|
|
||||||
|
|
||||||
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
||||||
@@ -243,9 +263,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
layer.self_attn.apply_qkv = apply_lora_qkv
|
layer.self_attn.apply_qkv = apply_lora_qkv
|
||||||
else:
|
else:
|
||||||
layer.self_attn.apply_qkv = original_apply_qkv
|
layer.self_attn.apply_qkv = original_apply_qkv
|
||||||
logging.warning(
|
LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx)
|
||||||
"unable to apply unsloth lora qkv patch to layer %d", idx
|
|
||||||
)
|
|
||||||
if cfg.unsloth_lora_o:
|
if cfg.unsloth_lora_o:
|
||||||
layer_modules = [
|
layer_modules = [
|
||||||
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"]
|
||||||
@@ -264,6 +282,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
|
|||||||
layer.self_attn.apply_o = apply_lora_o
|
layer.self_attn.apply_o = apply_lora_o
|
||||||
else:
|
else:
|
||||||
layer.self_attn.apply_o = original_apply_o
|
layer.self_attn.apply_o = original_apply_o
|
||||||
logging.warning(
|
LOG.warning(
|
||||||
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
"unable to apply unsloth lora o_proj patch to layer %d", idx
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_unsloth_layernorm():
    """Replace the Llama RMSNorm class with one backed by unsloth's kernel.

    Rebinds ``transformers.models.llama.modeling_llama.LlamaRMSNorm`` to a
    drop-in ``nn.Module`` whose ``forward`` delegates to
    ``unsloth.kernels.rms_layernorm.Fast_RMS_Layernorm``. If the required
    imports fail, a warning is logged and nothing is patched.
    """
    # Only the imports are guarded: an ImportError from anywhere else should
    # not be reported as "missing unsloth library".
    try:
        import transformers.models.llama.modeling_llama
        from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm
    except ImportError:
        LOG.warning("missing unsloth library")
    else:

        class LlamaRMSNorm(nn.Module):
            """LlamaRMSNorm"""

            def __init__(self, hidden_size, eps=1e-6):
                """
                LlamaRMSNorm is equivalent to T5LayerNorm
                """
                super().__init__()
                self.weight = nn.Parameter(torch.ones(hidden_size))
                self.variance_epsilon = eps

            def forward(self, hidden_states):
                # NOTE(review): the trailing positional `False` is passed
                # straight to the unsloth kernel; presumably a model-variant
                # flag -- confirm against the unsloth signature.
                return Fast_RMS_Layernorm.apply(
                    hidden_states, self.weight, self.variance_epsilon, False
                )

        # Explicit flag for consistency with the other unsloth patch messages.
        LOG.info("patching with unsloth.kernels.rms_layernorm", main_process_only=True)
        transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Module for pydantic models for configuration
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from importlib.metadata import version
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
|
from pydantic import BaseModel, Field, conlist, field_validator, model_validator
|
||||||
@@ -596,6 +597,8 @@ class AxolotlInputConfig(
|
|||||||
unsloth_lora_mlp: Optional[bool] = None
|
unsloth_lora_mlp: Optional[bool] = None
|
||||||
unsloth_lora_qkv: Optional[bool] = None
|
unsloth_lora_qkv: Optional[bool] = None
|
||||||
unsloth_lora_o: Optional[bool] = None
|
unsloth_lora_o: Optional[bool] = None
|
||||||
|
unsloth_rms_norm: Optional[bool] = None
|
||||||
|
unsloth_rope: Optional[bool] = None
|
||||||
|
|
||||||
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
deepspeed: Optional[Union[str, Dict[str, Any]]] = None
|
||||||
fsdp: Optional[List[str]] = None
|
fsdp: Optional[List[str]] = None
|
||||||
@@ -1164,6 +1167,21 @@ class AxolotlInputConfig(
|
|||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
    """Reject xformers 0.0.27 when an unsloth LoRA kernel option is enabled.

    The check only runs when at least one of ``unsloth_lora_mlp``,
    ``unsloth_lora_qkv`` or ``unsloth_lora_o`` is truthy in the raw config.

    Args:
        data: raw configuration mapping, before field validation.

    Returns:
        ``data`` unchanged.

    Raises:
        ValueError: if the installed xformers version is exactly ``0.0.27``.
    """
    # Local import: the module header only brings in `version`.
    from importlib.metadata import PackageNotFoundError

    if (
        data.get("unsloth_lora_mlp")
        or data.get("unsloth_lora_qkv")
        or data.get("unsloth_lora_o")
    ):
        try:
            xformers_version = version("xformers")
        except PackageNotFoundError:
            # xformers is not installed, so the unsupported release cannot be
            # present; skip the check instead of leaking a metadata lookup
            # error out of config validation.
            return data
        if xformers_version == "0.0.27":
            raise ValueError(
                "xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
            )
    return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_torch_compile_deepspeed(cls, data):
|
def check_torch_compile_deepspeed(cls, data):
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Module for models and model loading"""
|
"""Module for models and model loading"""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
|
|||||||
"Please make sure to point to a GPTQ model."
|
"Please make sure to point to a GPTQ model."
|
||||||
)
|
)
|
||||||
|
|
||||||
if not cfg.gptq and quant_config_exists:
|
if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
"model_config.quantization_config is set but `gptq` flag is not. "
|
||||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
||||||
@@ -358,6 +358,10 @@ def load_model(
|
|||||||
patch_llama_cross_entropy()
|
patch_llama_cross_entropy()
|
||||||
if cfg.flash_attn_rms_norm:
|
if cfg.flash_attn_rms_norm:
|
||||||
patch_llama_rms_norm()
|
patch_llama_rms_norm()
|
||||||
|
elif cfg.unsloth_rms_norm:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
|
||||||
|
|
||||||
|
patch_unsloth_layernorm()
|
||||||
if cfg.unsloth_cross_entropy_loss:
|
if cfg.unsloth_cross_entropy_loss:
|
||||||
from axolotl.monkeypatch.unsloth_ import (
|
from axolotl.monkeypatch.unsloth_ import (
|
||||||
integrate_cross_entropy_loss_patch,
|
integrate_cross_entropy_loss_patch,
|
||||||
@@ -884,6 +888,15 @@ def load_model(
|
|||||||
|
|
||||||
integrate_lora_patch(model, cfg)
|
integrate_lora_patch(model, cfg)
|
||||||
|
|
||||||
|
if cfg.unsloth_rope:
|
||||||
|
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
|
||||||
|
|
||||||
|
integrate_rope_embeddings()
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, lora_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user