Unsloth optims for Llama (#1609)
* WIP for unsloth integrations
* import the unsloth code in the right context
* add unsloth mlp, qkv, o lora optimizations
* apply unsloth mlp and qkv kernels
This commit is contained in:
@@ -390,6 +390,16 @@ def load_model(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
)
|
||||
|
||||
if cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch()
|
||||
|
||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora()
|
||||
|
||||
# Modify mistral derived models
|
||||
if (
|
||||
cfg.model_config_type == "mistral"
|
||||
@@ -828,6 +838,15 @@ def load_model(
|
||||
if cfg.adapter is not None:
|
||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||
|
||||
if cfg.unsloth_lora_mlp:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch
|
||||
|
||||
integrate_lora_mlp_patch(model)
|
||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_lora_patch
|
||||
|
||||
integrate_lora_patch(model, cfg)
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, lora_config
|
||||
|
||||
|
||||
Reference in New Issue
Block a user