Compare commits

...

2 Commits

Author     SHA1        Message                                Date
Wing Lian  080612219b  use even if not using sample packing   2023-10-13 17:54:35 -04:00
Wing Lian  f95858d369  alternate impl of NEFT                  2023-10-13 17:45:24 -04:00
2 changed files with 50 additions and 12 deletions

File 1 of 2

@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
     rms_norm: Optional[bool] = False,
+    noisy_embeddings_alpha: Optional[int] = False,
 ):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
     transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+    if noisy_embeddings_alpha:
+        transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
+            llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
+        )
+    else:
+        transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
+            llama_model_get_inputs_embeds
+        )
     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -411,6 +421,28 @@ def generate_qkv(
     )
 
 
+def llama_model_get_inputs_embeds(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    noisy_embeddings_alpha: Optional[int] = None,
+):
+    inputs_embeds = self.embed_tokens(input_ids)
+    if noisy_embeddings_alpha:
+        input_mask = attention_mask.to(inputs_embeds)  # B x L
+        input_lengths = torch.sum(input_mask, 1)  # B
+        noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
+        delta = noise_ * input_mask.unsqueeze(2)
+        dims = input_lengths * inputs_embeds.size(-1)
+        mag = noisy_embeddings_alpha / torch.sqrt(dims)
+        delta = (delta * mag.view(-1, 1, 1)).detach()
+        inputs_embeds += delta
+    return inputs_embeds
+
+
 def llama_model_forward(
     self,
     input_ids: torch.LongTensor = None,
@@ -477,7 +509,8 @@ def llama_model_forward(
         cu_seqlens = cu_seqlens.squeeze()
 
     if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
+
     # embed positions
     if attention_mask is None:
         attention_mask = torch.ones(
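The new get_inputs_embeds hook is a NEFT-style noisy-embedding step: uniform noise in [-1, 1] is zeroed at padding positions via the attention mask, then scaled so the per-sequence noise magnitude is noisy_embeddings_alpha / sqrt(L * d), with L the number of real tokens and d the embedding dimension. A minimal, self-contained sketch of that computation (illustrative helper name and toy shapes, plain PyTorch rather than the patched module):

# Sketch of the noise computation above; neft_noise is an illustrative name.
import torch


def neft_noise(inputs_embeds, attention_mask, alpha):
    input_mask = attention_mask.to(inputs_embeds)        # B x L, 1.0 for real tokens
    input_lengths = torch.sum(input_mask, 1)             # B, real tokens per sequence
    noise = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
    delta = noise * input_mask.unsqueeze(2)              # zero noise at padded positions
    dims = input_lengths * inputs_embeds.size(-1)        # L * d per sequence
    mag = alpha / torch.sqrt(dims)                       # per-sequence scale alpha / sqrt(L * d)
    return inputs_embeds + (delta * mag.view(-1, 1, 1)).detach()


# toy usage: batch of 2, seq len 4, hidden size 8, alpha = 5
embeds = torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
noised = neft_noise(embeds, mask, alpha=5)
print((noised - embeds).abs().amax(dim=(1, 2)))  # each entry bounded by 5 / sqrt(L * 8)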

File 2 of 2

@@ -136,7 +136,11 @@ def load_model(
         replace_stablelm_attn_with_flash_attn(cfg.base_model)
 
-    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
+    if (
+        cfg.is_llama_derived_model
+        and cfg.flash_attention
+        and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
+    ):
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
@@ -147,6 +151,7 @@ def load_model(
                 packed=cfg.sample_packing,
                 cross_entropy=cfg.flash_attn_cross_entropy,
                 rms_norm=cfg.flash_attn_rms_norm,
+                noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
             )
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -180,16 +185,16 @@ def load_model(
         LOG.info("patching with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
-        from axolotl.monkeypatch.llama_embeddings_hijack import (
-            replace_llama_embeddings_with_uniform_distribution,
-        )
-
-        LOG.info("patching with noisy embeddings")
-        replace_llama_embeddings_with_uniform_distribution(
-            noise_alpha=cfg.noisy_embedding_alpha
-        )
-
+    # if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+    #     from axolotl.monkeypatch.llama_embeddings_hijack import (
+    #         replace_llama_embeddings_with_uniform_distribution,
+    #     )
+    #
+    #     LOG.info("patching with noisy embeddings")
+    #     replace_llama_embeddings_with_uniform_distribution(
+    #         noise_alpha=cfg.noisy_embedding_alpha
+    #     )
+    #
     if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
         from axolotl.monkeypatch.mistral_embeddings_hijack import (
             replace_mistral_embeddings_with_uniform_distribution,
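Combined with the follow-up commit "use even if not using sample packing", the flash-attention patch becomes the single place that wires in the noisy-embedding hook for llama-derived models: a truthy noisy_embeddings_alpha enables it even when sample packing is off, while the older llama_embeddings_hijack path above is commented out. A hedged usage sketch of the patched entry point as it appears in this diff (the alpha value is illustrative):

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_attn_with_flash_attn,
)

# packed can stay False here; a truthy noisy_embeddings_alpha is enough to
# install the get_inputs_embeds patch, and None/0 leaves embeddings untouched.
replace_llama_attn_with_flash_attn(
    packed=False,
    noisy_embeddings_alpha=5,  # illustrative NEFT noise scale
)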