Compare commits
2 Commits
feat/lmeva
...
neft-v2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
080612219b | ||
|
|
f95858d369 |
@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
noisy_embeddings_alpha: Optional[int] = False,
|
||||
):
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
||||
if noisy_embeddings_alpha:
|
||||
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
|
||||
llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
|
||||
)
|
||||
else:
|
||||
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
|
||||
llama_model_get_inputs_embeds
|
||||
)
|
||||
|
||||
if packed:
|
||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||
transformers.models.llama.modeling_llama.LlamaModel.forward = (
|
||||
@@ -411,6 +421,28 @@ def generate_qkv(
|
||||
)
|
||||
|
||||
|
||||
def llama_model_get_inputs_embeds(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    noisy_embeddings_alpha: Optional[int] = None,
):
    """Embed ``input_ids`` and optionally add NEFTune-style uniform noise.

    When ``noisy_embeddings_alpha`` is truthy, per-sequence noise drawn from
    U(-1, 1) is scaled to magnitude ``alpha / sqrt(L * d)`` — where ``L`` is
    the number of unmasked tokens in the sequence and ``d`` the embedding
    dimension — and added to the embeddings, with padded positions zeroed.

    Args:
        self: object exposing ``embed_tokens`` (monkeypatched onto LlamaModel).
        input_ids: ``(batch, seq_len)`` token-id tensor.
        attention_mask: ``(batch, seq_len)`` mask of valid positions. May be
            ``None``, in which case every position is treated as valid.
        noisy_embeddings_alpha: noise scale; falsy disables noising.

    Returns:
        ``(batch, seq_len, d)`` embedding tensor, noised when enabled.
    """
    inputs_embeds = self.embed_tokens(input_ids)

    if noisy_embeddings_alpha:
        # Fix: the patched forward calls this BEFORE it builds its default
        # all-ones attention mask, so attention_mask can legitimately be
        # None here; treat that as "no padding" instead of crashing.
        if attention_mask is None:
            input_mask = torch.ones(
                inputs_embeds.shape[:2],
                dtype=inputs_embeds.dtype,
                device=inputs_embeds.device,
            )  # B x L
        else:
            input_mask = attention_mask.to(inputs_embeds)  # B x L
        input_lengths = torch.sum(input_mask, 1)  # B

        noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
        # Zero out noise on padded positions.
        delta = noise_ * input_mask.unsqueeze(2)
        dims = input_lengths * inputs_embeds.size(-1)
        mag = noisy_embeddings_alpha / torch.sqrt(dims)
        # detach: the noise is a constant w.r.t. autograd, not a learnable term.
        delta = (delta * mag.view(-1, 1, 1)).detach()
        inputs_embeds += delta

    return inputs_embeds
|
||||
|
||||
|
||||
def llama_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
@@ -477,7 +509,8 @@ def llama_model_forward(
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
|
||||
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
|
||||
@@ -136,7 +136,11 @@ def load_model(
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
and cfg.flash_attention
|
||||
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
|
||||
):
|
||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
@@ -147,6 +151,7 @@ def load_model(
|
||||
packed=cfg.sample_packing,
|
||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
|
||||
)
|
||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
@@ -180,16 +185,16 @@ def load_model(
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
||||
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
||||
replace_llama_embeddings_with_uniform_distribution,
|
||||
)
|
||||
|
||||
LOG.info("patching with noisy embeddings")
|
||||
replace_llama_embeddings_with_uniform_distribution(
|
||||
noise_alpha=cfg.noisy_embedding_alpha
|
||||
)
|
||||
|
||||
# if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
||||
# from axolotl.monkeypatch.llama_embeddings_hijack import (
|
||||
# replace_llama_embeddings_with_uniform_distribution,
|
||||
# )
|
||||
#
|
||||
# LOG.info("patching with noisy embeddings")
|
||||
# replace_llama_embeddings_with_uniform_distribution(
|
||||
# noise_alpha=cfg.noisy_embedding_alpha
|
||||
# )
|
||||
#
|
||||
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
||||
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
||||
replace_mistral_embeddings_with_uniform_distribution,
|
||||
|
||||
Reference in New Issue
Block a user