Unsloth rope (#1767)

* Add unsloth rope embeddings support

* support for models weights in 4bit and do some memory gc

* use accelerate logger

* add unsloth llama rms norm optims

* update docs for unsloth

* more docs info
This commit is contained in:
Wing Lian
2024-07-18 14:54:41 -04:00
committed by GitHub
parent c86c32a627
commit 7830fe04b5
6 changed files with 138 additions and 11 deletions

View File

@@ -1,7 +1,7 @@
"""Module for models and model loading"""
# pylint: disable=too-many-lines
import gc
import logging
import math
import os
@@ -94,7 +94,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
"Please make sure to point to a GPTQ model."
)
if not cfg.gptq and quant_config_exists:
if not cfg.gptq and quant_config_exists and not cfg.load_in_4bit:
raise ValueError(
"model_config.quantization_config is set but `gptq` flag is not. "
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
@@ -358,6 +358,10 @@ def load_model(
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
elif cfg.unsloth_rms_norm:
from axolotl.monkeypatch.unsloth_ import patch_unsloth_layernorm
patch_unsloth_layernorm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
@@ -884,6 +888,15 @@ def load_model(
integrate_lora_patch(model, cfg)
if cfg.unsloth_rope:
from axolotl.monkeypatch.unsloth_ import integrate_rope_embeddings
integrate_rope_embeddings()
for _ in range(3):
gc.collect()
torch.cuda.empty_cache()
# TODO resume_from_checkpoint handling
return model, lora_config