Compare commits

...

2 Commits
dft ... neft-v2

Author SHA1 Message Date
Wing Lian
080612219b use even if not using sample packing 2023-10-13 17:54:35 -04:00
Wing Lian
f95858d369 alternate impl of NEFT 2023-10-13 17:45:24 -04:00
2 changed files with 50 additions and 12 deletions

View File

@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False,
noisy_embeddings_alpha: Optional[int] = False,
):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if noisy_embeddings_alpha:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
)
else:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
llama_model_get_inputs_embeds
)
if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -411,6 +421,28 @@ def generate_qkv(
)
def llama_model_get_inputs_embeds(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
noisy_embeddings_alpha: Optional[int] = None,
):
inputs_embeds = self.embed_tokens(input_ids)
if noisy_embeddings_alpha:
input_mask = attention_mask.to(inputs_embeds) # B x L
input_lengths = torch.sum(input_mask, 1) # B
noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
delta = noise_ * input_mask.unsqueeze(2)
dims = input_lengths * inputs_embeds.size(-1)
mag = noisy_embeddings_alpha / torch.sqrt(dims)
delta = (delta * mag.view(-1, 1, 1)).detach()
inputs_embeds += delta
return inputs_embeds
def llama_model_forward(
self,
input_ids: torch.LongTensor = None,
@@ -477,7 +509,8 @@ def llama_model_forward(
cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(

View File

@@ -136,7 +136,11 @@ def load_model(
replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if (
cfg.is_llama_derived_model
and cfg.flash_attention
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
):
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
@@ -147,6 +151,7 @@ def load_model(
packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm,
noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
)
elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -180,16 +185,16 @@ def load_model(
LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.llama_embeddings_hijack import (
replace_llama_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_llama_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
# if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
# from axolotl.monkeypatch.llama_embeddings_hijack import (
# replace_llama_embeddings_with_uniform_distribution,
# )
#
# LOG.info("patching with noisy embeddings")
# replace_llama_embeddings_with_uniform_distribution(
# noise_alpha=cfg.noisy_embedding_alpha
# )
#
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,