Unsloth optims for Llama (#1609)

* WIP for unsloth integrations

* import the unsloth code in the right context

* add unsloth mlp, qkv, o lora optimizations

* apply unsloth mlp and qkv kernels
This commit is contained in:
Wing Lian
2024-05-20 09:55:06 -04:00
committed by GitHub
parent 702a669cad
commit 8a1572a831
3 changed files with 291 additions and 0 deletions

View File

@@ -390,6 +390,16 @@ def load_model(
"Shifted-sparse attention not currently implemented without flash attention."
)
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch()
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
# Modify mistral derived models
if (
cfg.model_config_type == "mistral"
@@ -828,6 +838,15 @@ def load_model(
if cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", model.device)
if cfg.unsloth_lora_mlp:
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
integrate_lora_mlp_patch(model)
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
integrate_lora_patch(model, cfg)
# TODO resume_from_checkpoint handling
return model, lora_config