From f95858d3695162b3b3fcfcbdc0f637fa570dfd68 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 13 Oct 2023 17:45:24 -0400
Subject: [PATCH] alternate impl of NEFT

---
 .../monkeypatch/llama_attn_hijack_flash.py | 35 ++++++++++++++++++-
 src/axolotl/utils/models.py                | 21 +++++------
 2 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 4f6b71575..fb271e6aa 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
     rms_norm: Optional[bool] = False,
+    noisy_embeddings_alpha: Optional[int] = False,
 ):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
     transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+    if noisy_embeddings_alpha:
+        transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
+            llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
+        )
+    else:
+        transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
+            llama_model_get_inputs_embeds
+        )
+
     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -411,6 +421,28 @@ def generate_qkv(
     )
 
 
+def llama_model_get_inputs_embeds(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    noisy_embeddings_alpha: Optional[int] = None,
+):
+    inputs_embeds = self.embed_tokens(input_ids)
+
+    if noisy_embeddings_alpha:
+        input_mask = attention_mask.to(inputs_embeds)  # B x L
+        input_lengths = torch.sum(input_mask, 1)  # B
+
+        noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
+        delta = noise_ * input_mask.unsqueeze(2)
+        dims = input_lengths * inputs_embeds.size(-1)
+        mag = noisy_embeddings_alpha / torch.sqrt(dims)
+        delta = (delta * mag.view(-1, 1, 1)).detach()
+        inputs_embeds += delta
+
+    return inputs_embeds
+
+
 def llama_model_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -477,7 +509,8 @@ def llama_model_forward(
         cu_seqlens = cu_seqlens.squeeze()
 
     if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
+
     # embed positions
     if attention_mask is None:
         attention_mask = torch.ones(
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index c133e9eb6..19b13a342 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -147,6 +147,7 @@ def load_model(
             packed=cfg.sample_packing,
             cross_entropy=cfg.flash_attn_cross_entropy,
             rms_norm=cfg.flash_attn_rms_norm,
+            noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
         )
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -180,16 +181,16 @@ def load_model(
         LOG.info("patching with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
-        from axolotl.monkeypatch.llama_embeddings_hijack import (
-            replace_llama_embeddings_with_uniform_distribution,
-        )
-
-        LOG.info("patching with noisy embeddings")
-        replace_llama_embeddings_with_uniform_distribution(
-            noise_alpha=cfg.noisy_embedding_alpha
-        )
-
+    # if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+    #     from axolotl.monkeypatch.llama_embeddings_hijack import (
+    #         replace_llama_embeddings_with_uniform_distribution,
+    #     )
+    #
+    #     LOG.info("patching with noisy embeddings")
+    #     replace_llama_embeddings_with_uniform_distribution(
+    #         noise_alpha=cfg.noisy_embedding_alpha
+    #     )
+    #
     if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
         from axolotl.monkeypatch.mistral_embeddings_hijack import (
             replace_mistral_embeddings_with_uniform_distribution,
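
Note (not part of the patch): the patched get_inputs_embeds hook above applies NEFTune-style noisy embeddings. For each sequence, uniform noise in [-1, 1] is scaled by alpha / sqrt(L * d), where L is the number of non-padding tokens and d is the embedding dimension, and added to the token embeddings. Below is a minimal standalone sketch of that computation for reference; the helper name neftune_noise and the dummy tensor shapes are illustrative assumptions, not code from the patch.

import torch


def neftune_noise(inputs_embeds, attention_mask, alpha):
    # Illustrative sketch (assumed helper, not from the patch): add uniform
    # noise scaled by alpha / sqrt(L * d) to non-padding token embeddings,
    # mirroring llama_model_get_inputs_embeds in the hunk above.
    input_mask = attention_mask.to(inputs_embeds)             # (B, L), 1.0 for real tokens
    input_lengths = torch.sum(input_mask, 1)                  # (B,) unpadded lengths L
    noise = torch.zeros_like(inputs_embeds).uniform_(-1, 1)   # (B, L, d)
    delta = noise * input_mask.unsqueeze(2)                   # zero noise at padding positions
    dims = input_lengths * inputs_embeds.size(-1)             # L * d per sequence
    mag = alpha / torch.sqrt(dims)                            # per-sequence scale
    return inputs_embeds + (delta * mag.view(-1, 1, 1)).detach()


# Hypothetical usage with dummy tensors:
embeds = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
noisy = neftune_noise(embeds, mask, alpha=5.0)
print(noisy.shape)  # torch.Size([2, 5, 8])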