add noisy embedding (#721)
* add noisy embedding * fix format * Update README.md * Update README.md * linter issues * caseus fixes --------- Co-authored-by: Maxime <maxime@nope.no>
This commit is contained in:
@@ -672,6 +672,11 @@ adam_epsilon:
|
|||||||
# Gradient clipping max norm
|
# Gradient clipping max norm
|
||||||
max_grad_norm:
|
max_grad_norm:
|
||||||
|
|
||||||
|
# Augmentation techniques
|
||||||
|
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||||
|
# currently only supported on Llama and Mistral
|
||||||
|
noisy_embedding_alpha:
|
||||||
|
|
||||||
# Whether to bettertransformers
|
# Whether to bettertransformers
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||||
|
|||||||
40
src/axolotl/monkeypatch/llama_embeddings_hijack.py
Normal file
40
src/axolotl/monkeypatch/llama_embeddings_hijack.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def noised_embed(orig_embed, noise_alpha, model):
|
||||||
|
def new_func(input_ids):
|
||||||
|
# during training, we add noise to the embedding
|
||||||
|
# during generation, we don't add noise to the embedding
|
||||||
|
if model.training:
|
||||||
|
embed_init = orig_embed(input_ids)
|
||||||
|
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||||
|
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||||
|
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||||
|
-mag_norm, mag_norm
|
||||||
|
)
|
||||||
|
return orig_embed(input_ids)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
def post_init(orig_post_init):
|
||||||
|
def new_func(self):
|
||||||
|
orig_post_init(self)
|
||||||
|
self.embed_tokens.forward = noised_embed(
|
||||||
|
self.embed_tokens.forward, noise_alpha, self
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
||||||
|
)
|
||||||
40
src/axolotl/monkeypatch/mistral_embeddings_hijack.py
Normal file
40
src/axolotl/monkeypatch/mistral_embeddings_hijack.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""
|
||||||
|
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers.models.mistral.modeling_mistral
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
def noised_embed(orig_embed, noise_alpha, model):
|
||||||
|
def new_func(input_ids):
|
||||||
|
# during training, we add noise to the embedding
|
||||||
|
# during generation, we don't add noise to the embedding
|
||||||
|
if model.training:
|
||||||
|
embed_init = orig_embed(input_ids)
|
||||||
|
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||||
|
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||||
|
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||||
|
-mag_norm, mag_norm
|
||||||
|
)
|
||||||
|
return orig_embed(input_ids)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
def post_init(orig_post_init):
|
||||||
|
def new_func(self):
|
||||||
|
orig_post_init(self)
|
||||||
|
self.embed_tokens.forward = noised_embed(
|
||||||
|
self.embed_tokens.forward, noise_alpha, self
|
||||||
|
)
|
||||||
|
|
||||||
|
return new_func
|
||||||
|
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
||||||
|
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
||||||
|
)
|
||||||
@@ -180,6 +180,26 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
|
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
||||||
|
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
||||||
|
replace_llama_embeddings_with_uniform_distribution,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with noisy embeddings")
|
||||||
|
replace_llama_embeddings_with_uniform_distribution(
|
||||||
|
noise_alpha=cfg.noisy_embedding_alpha
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
||||||
|
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
||||||
|
replace_mistral_embeddings_with_uniform_distribution,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG.info("patching with noisy embeddings")
|
||||||
|
replace_mistral_embeddings_with_uniform_distribution(
|
||||||
|
noise_alpha=cfg.noisy_embedding_alpha
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||||
replace_llama_rope_with_xpos_rope,
|
replace_llama_rope_with_xpos_rope,
|
||||||
|
|||||||
Reference in New Issue
Block a user