From 7830fe04b5ef226e0d0eff8a7285ce263ee3de47 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 18 Jul 2024 14:54:41 -0400 Subject: [PATCH] Unsloth rope (#1767) * Add unsloth rope embeddings support * support for models weights in 4bit and do some memory gc * use accelerate logger * add unsloth llama rms norm optims * update docs for unsloth * more docs info --- README.md | 1 + _quarto.yml | 1 + docs/unsloth.qmd | 49 +++++++++++++++ src/axolotl/monkeypatch/unsloth_.py | 63 ++++++++++++++++--- .../config/models/input/v0_4_1/__init__.py | 18 ++++++ src/axolotl/utils/models.py | 17 ++++- 6 files changed, 138 insertions(+), 11 deletions(-) create mode 100644 docs/unsloth.qmd diff --git a/README.md b/README.md index 0a4e2e5a1..fd293bd04 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ Features: - [Multipack](./docs/multipack.qmd) - [RLHF & DPO](./docs/rlhf.qmd) - [Dataset Pre-Processing](./docs/dataset_preprocessing.qmd) + - [Unsloth](./docs/unsloth.qmd) - [Common Errors](#common-errors-) - [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training) - [Debugging Axolotl](#debugging-axolotl) diff --git a/_quarto.yml b/_quarto.yml index 009fa8056..6b2eed971 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -36,6 +36,7 @@ website: - docs/nccl.qmd - docs/mac.qmd - docs/multi-node.qmd + - docs/unsloth.qmd - section: "Dataset Formats" contents: docs/dataset-formats/* - section: "Reference" diff --git a/docs/unsloth.qmd b/docs/unsloth.qmd new file mode 100644 index 000000000..390609fd3 --- /dev/null +++ b/docs/unsloth.qmd @@ -0,0 +1,49 @@ +--- +title: "Unsloth" +description: "Hyper-optimized QLoRA finetuning for single GPUs" +--- + +### Overview + +Unsloth provides hand-written optimized kernels for LLM finetuning that slightly improve speed and VRAM over +standard industry baselines. 
+ + +### Installation + +The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most +up-to-date libraries. + +```bash +pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git" +pip install --no-deps --force-reinstall xformers==0.0.26.post1 +``` + +### Using unsloth with Axolotl + +Axolotl exposes a few configuration options to try out unsloth and get most of the performance gains. + +Our unsloth integration is currently limited to the following model architectures: + - llama + +These options are specific to LoRA finetuning and cannot be used for multi-GPU finetuning: +```yaml +unsloth_lora_mlp: true +unsloth_lora_qkv: true +unsloth_lora_o: true +``` + +These options are composable and can be used with multi-GPU finetuning: +```yaml +unsloth_cross_entropy_loss: true +unsloth_rms_norm: true +unsloth_rope: true +``` + +### Limitations + +- Single GPU only for the `unsloth_lora_*` options; i.e. no multi-GPU support +- No deepspeed or FSDP support for the `unsloth_lora_*` options (requires multi-GPU) +- LoRA + QLoRA support only. No full fine tunes or fp8 support. +- Limited model architecture support. Llama, Phi, Gemma, Mistral only +- No MoE support. 
diff --git a/src/axolotl/monkeypatch/unsloth_.py b/src/axolotl/monkeypatch/unsloth_.py index 6af3046e1..b1f0bddc0 100644 --- a/src/axolotl/monkeypatch/unsloth_.py +++ b/src/axolotl/monkeypatch/unsloth_.py @@ -1,18 +1,20 @@ """module for patching with unsloth optimizations""" import inspect -import logging import re import types from typing import Tuple +import torch +from accelerate.logging import get_logger from peft import PeftModelForCausalLM +from torch import nn from transformers.models.llama.modeling_llama import ( LlamaFlashAttention2, LlamaForCausalLM, ) -LOG = logging.getLogger("axolotl.monkeypatch.unsloth") +LOG = get_logger("axolotl.monkeypatch.unsloth") ORIGINAL_CEL_CODE = """ if labels is not None: # Shift so that tokens < n predict n @@ -137,7 +139,7 @@ def integrate_cross_entropy_loss_patch(): globals(), ) exec(forward, globals()) # pylint: disable=exec-used # nosec B102 - print("patching unsloth fast_cross_entropy_loss") + LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True) LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 @@ -179,12 +181,30 @@ def patch_self_attn_lora(): globals(), ) exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102 - print("patching unsloth attn lora") + LOG.info("patching unsloth attn lora", main_process_only=True) LlamaFlashAttention2.forward = ( unsloth_attn_forward # pylint: disable=undefined-variable # noqa: F821 ) +def integrate_rope_embeddings(): + import transformers.models.llama.modeling_llama + from unsloth.kernels.rope_embedding import fast_rope_embedding + + def apply_rotary_pos_emb( # pylint: disable=unused-argument + q, # pylint: disable=invalid-name + k, # pylint: disable=invalid-name + cos, + sin, + position_ids=None, + unsqueeze_dim=1, + ): + return fast_rope_embedding(q, k, cos, sin) + + LOG.info("patching unsloth RoPE embeddings", main_process_only=True) + 
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb + + def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): if peft_model.base_model.config.model_type in ["llama", "mistral"]: from unsloth.kernels import apply_lora_mlp_swiglu @@ -217,7 +237,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM): if is_mlp_lora and mlp_no_bias and mlp_not_dora: layer.mlp.forward = types.MethodType(apply_lora_mlp, layer.mlp) else: - logging.warning("unable to apply unsloth lora mlp patch to layer %d", idx) + LOG.warning("unable to apply unsloth lora mlp patch to layer %d", idx) def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): @@ -243,9 +263,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_qkv = apply_lora_qkv else: layer.self_attn.apply_qkv = original_apply_qkv - logging.warning( - "unable to apply unsloth lora qkv patch to layer %d", idx - ) + LOG.warning("unable to apply unsloth lora qkv patch to layer %d", idx) if cfg.unsloth_lora_o: layer_modules = [ getattr(layer.self_attn, linear_proj) for linear_proj in ["o_proj"] @@ -264,6 +282,33 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg): layer.self_attn.apply_o = apply_lora_o else: layer.self_attn.apply_o = original_apply_o - logging.warning( + LOG.warning( "unable to apply unsloth lora o_proj patch to layer %d", idx ) + + +def patch_unsloth_layernorm(): + try: + import transformers.models.llama.modeling_llama + from unsloth.kernels.rms_layernorm import Fast_RMS_Layernorm + + class LlamaRMSNorm(nn.Module): + """LlamaRMSNorm""" + + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + return Fast_RMS_Layernorm.apply( + hidden_states, self.weight, self.variance_epsilon, False + ) + + LOG.info("patching with 
unsloth.kernels.rms_layernorm") + transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm + except ImportError: + LOG.warning("missing unsloth library") diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index f0c6fa0ea..f3acbbc68 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -7,6 +7,7 @@ Module for pydantic models for configuration import logging import os from enum import Enum +from importlib.metadata import version from typing import Any, Dict, List, Literal, Optional, Tuple, Union from pydantic import BaseModel, Field, conlist, field_validator, model_validator @@ -596,6 +597,8 @@ class AxolotlInputConfig( unsloth_lora_mlp: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None unsloth_lora_o: Optional[bool] = None + unsloth_rms_norm: Optional[bool] = None + unsloth_rope: Optional[bool] = None deepspeed: Optional[Union[str, Dict[str, Any]]] = None fsdp: Optional[List[str]] = None @@ -1164,6 +1167,21 @@ class AxolotlInputConfig( ) return data + @model_validator(mode="before") + @classmethod + def check_unsloth_xformers_version(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + xformers_version = version("xformers") + if xformers_version == "0.0.27": + raise ValueError( + "xformers version 0.0.27 is not supported with unsloth. 
Please downgrade to 0.0.26.post1" + ) + return data + @model_validator(mode="before") @classmethod def check_torch_compile_deepspeed(cls, data): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 51ce5a29b..6185f0102 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -1,7 +1,7 @@ """Module for models and model loading""" # pylint: disable=too-many-lines - +import gc import logging import math import os @@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef "Please make sure to point to a GPTQ model." ) - if not cfg.gptq and quant_config_exists: + if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit: raise ValueError( "model_config.quantization_config is set but `gptq` flag is not. " "Please use the `gptq` flag to train quantized model or point to a non-quantized model." @@ -358,6 +358,10 @@ def load_model( patch_llama_cross_entropy() if cfg.flash_attn_rms_norm: patch_llama_rms_norm() + elif cfg.unsloth_rms_norm: + from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm + + patch_unsloth_layernorm() if cfg.unsloth_cross_entropy_loss: from axolotl.monkeypatch.unsloth_ import ( integrate_cross_entropy_loss_patch, @@ -884,6 +888,15 @@ def load_model( integrate_lora_patch(model, cfg) + if cfg.unsloth_rope: + from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings + + integrate_rope_embeddings() + + for _ in range(3): + gc.collect() + torch.cuda.empty_cache() + # TODO resume_from_checkpoint handling return model, lora_config