refactor neft patch to be more re-usable similar to trl's impl (#796)

This commit is contained in:
Wing Lian
2023-10-29 04:33:13 -04:00
committed by GitHub
parent 8b79ff0e94
commit 827ec3d274
6 changed files with 88 additions and 101 deletions

View File

@@ -180,26 +180,6 @@ def load_model(
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_mistral_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope,