Compare commits: kd-trainer...neft-v2 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 080612219b |  |
|  | f95858d369 |  |
src/axolotl/monkeypatch/llama_attn_hijack_flash.py

```diff
@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
     rms_norm: Optional[bool] = False,
+    noisy_embeddings_alpha: Optional[int] = False,
 ):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
     transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+    if noisy_embeddings_alpha:
+        transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
+            llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
+        )
+    else:
+        transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
+            llama_model_get_inputs_embeds
+        )
+
     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
```
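A note on the `partial` branch: `functools.partial` objects do not implement the descriptor protocol, so assigning one to a class attribute does not produce a bound method the way assigning a plain function (as in the `else` branch) does; `functools.partialmethod` is the binding-aware variant. A minimal sketch of the difference, with illustrative names that are not part of the diff:

```python
from functools import partial, partialmethod


def greet(self, name, excited=False):
    # stand-in for llama_model_get_inputs_embeds: first argument is `self`
    return f"{'HI' if excited else 'hi'} {name}, I am {self.label}"


class Model:
    label = "llama"


# partialmethod participates in attribute binding: `self` is passed for us.
Model.bound = partialmethod(greet, excited=True)
# partial is returned from the class as-is: `self` must be passed by hand.
Model.unbound = partial(greet, excited=True)

m = Model()
print(m.bound("sam"))       # HI sam, I am llama
print(m.unbound(m, "sam"))  # works only because we pass `m` explicitly
```

Since `llama_model_forward` later calls `self.get_inputs_embeds(input_ids, attention_mask)`, the `partial` branch is worth double-checking against this binding behavior.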
```diff
@@ -411,6 +421,28 @@ def generate_qkv(
     )
 
 
+def llama_model_get_inputs_embeds(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    noisy_embeddings_alpha: Optional[int] = None,
+):
+    inputs_embeds = self.embed_tokens(input_ids)
+
+    if noisy_embeddings_alpha:
+        input_mask = attention_mask.to(inputs_embeds)  # B x L
+        input_lengths = torch.sum(input_mask, 1)  # B
+
+        noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
+        delta = noise_ * input_mask.unsqueeze(2)
+        dims = input_lengths * inputs_embeds.size(-1)
+        mag = noisy_embeddings_alpha / torch.sqrt(dims)
+        delta = (delta * mag.view(-1, 1, 1)).detach()
+        inputs_embeds += delta
+
+    return inputs_embeds
+
+
 def llama_model_forward(
     self,
     input_ids: torch.LongTensor = None,
```
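For reference, this is the NEFTune-style noise schedule: uniform noise in [-1, 1] over real (unmasked) token positions, scaled by α/√(L·d) with L the per-sequence token count and d the embedding width, so the per-sequence noise norm is bounded by α. A self-contained sketch of the same arithmetic, with made-up shapes (not from the diff), that checks the resulting norm:

```python
import torch

alpha = 5.0          # noise scale; NEFTune typically uses 5-15
B, L, d = 2, 8, 16   # batch, max sequence length, embedding width
inputs_embeds = torch.randn(B, L, d)
attention_mask = torch.ones(B, L)
attention_mask[1, 4:] = 0  # second sequence has only 4 real tokens

input_mask = attention_mask.to(inputs_embeds)   # B x L
input_lengths = torch.sum(input_mask, 1)        # B, real tokens per sequence

noise = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
delta = noise * input_mask.unsqueeze(2)         # zero out padding positions
dims = input_lengths * inputs_embeds.size(-1)   # L_i * d per sequence
mag = alpha / torch.sqrt(dims)                  # alpha / sqrt(L_i * d)
delta = (delta * mag.view(-1, 1, 1)).detach()

# Each U(-1, 1) entry has E[x^2] = 1/3, so the per-sequence L2 norm of
# delta concentrates near alpha / sqrt(3) and never exceeds alpha.
print(delta.flatten(1).norm(dim=1))  # ~2.89 for alpha = 5
```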
```diff
@@ -477,7 +509,8 @@ def llama_model_forward(
         cu_seqlens = cu_seqlens.squeeze()
 
     if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
+
     # embed positions
     if attention_mask is None:
         attention_mask = torch.ones(
```
src/axolotl/utils/models.py

```diff
@@ -136,7 +136,11 @@ def load_model(
 
         replace_stablelm_attn_with_flash_attn(cfg.base_model)
 
-    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
+    if (
+        cfg.is_llama_derived_model
+        and cfg.flash_attention
+        and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
+    ):
         if cfg.device not in ["mps", "cpu"] and not inference:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
```
```diff
@@ -147,6 +151,7 @@ def load_model(
                 packed=cfg.sample_packing,
                 cross_entropy=cfg.flash_attn_cross_entropy,
                 rms_norm=cfg.flash_attn_rms_norm,
+                noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
             )
     elif cfg.is_llama_derived_model and cfg.xformers_attention:
         from axolotl.monkeypatch.llama_attn_hijack_xformers import (
```
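Taken together with the widened gate above, the feature should be switchable from the config alone. A minimal sketch of the gate's behavior with a stand-in `cfg` (hypothetical, for illustration; attribute names are the ones this code reads):

```python
from types import SimpleNamespace

cfg = SimpleNamespace(
    is_llama_derived_model=True,
    flash_attention=True,
    sample_packing=False,      # no longer required for the patch to fire
    noisy_embeddings_alpha=5,  # NEFTune-style noise scale; 0/None disables
)

should_patch = (
    cfg.is_llama_derived_model
    and cfg.flash_attention
    and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
)
print(should_patch)  # True: noise alone now triggers the flash-attn patch
```

Note the key spelling: this path reads the plural `cfg.noisy_embeddings_alpha`, while the older standalone hijack (commented out in the next hunk) read the singular `cfg.noisy_embedding_alpha`.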
```diff
@@ -180,16 +185,16 @@ def load_model(
         LOG.info("patching with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
-        from axolotl.monkeypatch.llama_embeddings_hijack import (
-            replace_llama_embeddings_with_uniform_distribution,
-        )
-
-        LOG.info("patching with noisy embeddings")
-        replace_llama_embeddings_with_uniform_distribution(
-            noise_alpha=cfg.noisy_embedding_alpha
-        )
-
+    # if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
+    #     from axolotl.monkeypatch.llama_embeddings_hijack import (
+    #         replace_llama_embeddings_with_uniform_distribution,
+    #     )
+    #
+    #     LOG.info("patching with noisy embeddings")
+    #     replace_llama_embeddings_with_uniform_distribution(
+    #         noise_alpha=cfg.noisy_embedding_alpha
+    #     )
+    #
     if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
         from axolotl.monkeypatch.mistral_embeddings_hijack import (
             replace_mistral_embeddings_with_uniform_distribution,
```