support for llama multipack using updated code/patches (#1754)

* support for llama multipack using updated code/patches

* also support unsloth patches

* incorrect arg

* add config validation for unsloth

* add missing return to validation

* add another missing return to validation
This commit is contained in:
Wing Lian
2024-07-16 17:36:29 -04:00
committed by GitHub
parent cfc533a7f7
commit 5f58555bd0
4 changed files with 95 additions and 21 deletions

View File

@@ -347,6 +347,27 @@ def load_model(
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
if cfg.is_llama_derived_model:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
patch_llama_cross_entropy,
patch_llama_rms_norm,
)
if cfg.flash_attn_cross_entropy:
patch_llama_cross_entropy()
if cfg.flash_attn_rms_norm:
patch_llama_rms_norm()
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import (
integrate_cross_entropy_loss_patch,
)
integrate_cross_entropy_loss_patch()
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
patch_self_attn_lora()
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block