Compare commits
4 Commits
embeddings
...
feature/at
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
956a177678 | ||
|
|
747e84d3bb | ||
|
|
c45a786039 | ||
|
|
70e6c28121 |
12
README.md
12
README.md
@@ -326,9 +326,9 @@ tokenizer_type: AutoTokenizer
|
|||||||
trust_remote_code:
|
trust_remote_code:
|
||||||
# use_fast option for tokenizer loading from_pretrained, default to True
|
# use_fast option for tokenizer loading from_pretrained, default to True
|
||||||
tokenizer_use_fast:
|
tokenizer_use_fast:
|
||||||
# resize the model embeddings when new tokens are added to multiples of N
|
# resize the model embeddings when new tokens are added to multiples of 32
|
||||||
# multiples of 32 are reported to improve training speed on some models
|
# this is reported to improve training speed on some models
|
||||||
resize_token_embeddings_multiple:
|
resize_token_embeddings_to_32x:
|
||||||
|
|
||||||
# whether you are training a 4-bit GPTQ quantized model
|
# whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
@@ -364,9 +364,6 @@ dataset_prepared_path: data/last_run_prepared
|
|||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # repo path
|
||||||
# push checkpoints to hub
|
# push checkpoints to hub
|
||||||
hub_model_id: # repo path to push finetuned model
|
hub_model_id: # repo path to push finetuned model
|
||||||
# how to push checkpoints to hub
|
|
||||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
|
||||||
hub_strategy:
|
|
||||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||||
# required to be true when used in combination with `push_dataset_to_hub`
|
# required to be true when used in combination with `push_dataset_to_hub`
|
||||||
hf_use_auth_token: # boolean
|
hf_use_auth_token: # boolean
|
||||||
@@ -435,8 +432,7 @@ learning_rate: 0.00003
|
|||||||
logging_steps:
|
logging_steps:
|
||||||
save_steps:
|
save_steps:
|
||||||
eval_steps:
|
eval_steps:
|
||||||
save_total_limit: # checkpoints saved at a time
|
save_total_limit:
|
||||||
max_steps:
|
|
||||||
|
|
||||||
# save model as safetensors (require safetensors package)
|
# save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
|||||||
|
|
||||||
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
||||||
cd flash-attention && \
|
cd flash-attention && \
|
||||||
git checkout v2.0.4 && \
|
git checkout v2.0.1 && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
cd csrc/fused_dense_lib && \
|
cd csrc/fused_dense_lib && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./lora-out
|
output_dir: ./lora-out
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
max_packed_sequence_len: 4096
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
@@ -49,8 +49,8 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention: true
|
||||||
flash_attention: true
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
@@ -64,3 +64,4 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
pad_token: "<pad>"
|
||||||
|
|||||||
@@ -18,8 +18,7 @@ adapter: qlora
|
|||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
max_packed_sequence_len: 4096
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
@@ -51,8 +50,8 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
xformers_attention:
|
xformers_attention: true
|
||||||
flash_attention: true
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 20
|
eval_steps: 20
|
||||||
@@ -66,3 +65,4 @@ special_tokens:
|
|||||||
bos_token: "<s>"
|
bos_token: "<s>"
|
||||||
eos_token: "</s>"
|
eos_token: "</s>"
|
||||||
unk_token: "<unk>"
|
unk_token: "<unk>"
|
||||||
|
pad_token: "<pad>"
|
||||||
|
|||||||
@@ -209,13 +209,7 @@ def train(
|
|||||||
cfg, train_dataset, eval_dataset
|
cfg, train_dataset, eval_dataset
|
||||||
)
|
)
|
||||||
barrier()
|
barrier()
|
||||||
if cfg.max_steps:
|
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||||
total_num_steps = min(
|
|
||||||
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
|
||||||
)
|
|
||||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
|
||||||
else:
|
|
||||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
|
||||||
|
|
||||||
if cfg.debug or "debug" in kwargs:
|
if cfg.debug or "debug" in kwargs:
|
||||||
LOG.info("check_dataset_labels...")
|
LOG.info("check_dataset_labels...")
|
||||||
|
|||||||
@@ -2,142 +2,38 @@
|
|||||||
|
|
||||||
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
|
||||||
|
|
||||||
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||||
|
flash_attn_kvpacked_func,
|
||||||
|
flash_attn_varlen_kvpacked_func,
|
||||||
|
flash_attn_varlen_qkvpacked_func,
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
from flash_attn.flash_attn_interface import (
|
||||||
|
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
|
||||||
|
)
|
||||||
from flash_attn.flash_attn_interface import (
|
from flash_attn.flash_attn_interface import (
|
||||||
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
def replace_llama_attn_with_flash_attn():
|
||||||
|
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||||
|
_prepare_decoder_attention_mask
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.Tensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
"""Input shape: Batch x Time x Channel
|
|
||||||
|
|
||||||
attention_mask: [bsz, q_len]
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = (
|
|
||||||
self.q_proj(hidden_states)
|
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
)
|
|
||||||
key_states = (
|
|
||||||
self.k_proj(hidden_states)
|
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
)
|
|
||||||
value_states = (
|
|
||||||
self.v_proj(hidden_states)
|
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
)
|
|
||||||
# [bsz, q_len, nh, hd]
|
|
||||||
# [bsz, nh, q_len, hd]
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
assert past_key_value is None, "past_key_value is not supported"
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
|
||||||
query_states, key_states, cos, sin, position_ids
|
|
||||||
)
|
|
||||||
# [bsz, nh, t, hd]
|
|
||||||
assert not output_attentions, "output_attentions is not supported"
|
|
||||||
assert not use_cache, "use_cache is not supported"
|
|
||||||
|
|
||||||
# Flash attention codes from
|
|
||||||
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
|
|
||||||
|
|
||||||
# transform the data into the format required by flash attention
|
|
||||||
qkv = torch.stack(
|
|
||||||
[query_states, key_states, value_states], dim=2
|
|
||||||
) # [bsz, nh, 3, q_len, hd]
|
|
||||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
|
||||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
|
||||||
# the attention_mask should be the same as the key_padding_mask
|
|
||||||
key_padding_mask = attention_mask
|
|
||||||
|
|
||||||
if key_padding_mask is None:
|
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
|
||||||
max_s = q_len
|
|
||||||
cu_q_lens = torch.arange(
|
|
||||||
0,
|
|
||||||
(bsz + 1) * q_len,
|
|
||||||
step=q_len,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=qkv.device,
|
|
||||||
)
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
|
||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
|
||||||
)
|
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
|
||||||
elif attention_mask.shape[0] == 1:
|
|
||||||
# special handling using sample packing
|
|
||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
|
||||||
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
|
||||||
cu_q_lens = cu_q_lens.squeeze()
|
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
|
||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
|
||||||
)
|
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
|
||||||
else:
|
|
||||||
nheads = qkv.shape[-2]
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
x = rearrange(qkv, "b s three h d -> b s (three h d)")
|
|
||||||
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
|
|
||||||
x_unpad = rearrange(
|
|
||||||
x_unpad,
|
|
||||||
"nnz (three h d) -> nnz three h d",
|
|
||||||
three=3,
|
|
||||||
h=nheads,
|
|
||||||
)
|
|
||||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
|
||||||
x_unpad,
|
|
||||||
cu_q_lens,
|
|
||||||
max_s,
|
|
||||||
0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
output = rearrange(
|
|
||||||
pad_input(
|
|
||||||
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
|
|
||||||
indices,
|
|
||||||
bsz,
|
|
||||||
q_len,
|
|
||||||
),
|
|
||||||
"b s (h d) -> b s h d",
|
|
||||||
h=nheads,
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
)
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
@@ -153,8 +49,310 @@ def _prepare_decoder_attention_mask(
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_attn_with_flash_attn():
|
def flashattn_forward(
|
||||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
self,
|
||||||
_prepare_decoder_attention_mask
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.Tensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""Input shape: Batch x Time x Channel
|
||||||
|
|
||||||
|
attention_mask: [bsz, q_len]
|
||||||
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if not hasattr(self, "pretraining_tp"):
|
||||||
|
self.pretraining_tp = 1
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
key_value_slicing = (
|
||||||
|
self.num_key_value_heads * self.head_dim
|
||||||
|
) // self.pretraining_tp
|
||||||
|
query_slices = self.q_proj.weight.split(
|
||||||
|
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||||
|
)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [
|
||||||
|
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [
|
||||||
|
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [
|
||||||
|
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
# [bsz, q_len, nh, hd]
|
||||||
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
warnings.warn(
|
||||||
|
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# flash-attn v2 start
|
||||||
|
#
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
# during training q,k,v always have same seqlen
|
||||||
|
assert key_states.shape == query_states.shape
|
||||||
|
is_causal = True
|
||||||
|
else:
|
||||||
|
# turn off FA causal mask after first inference autoregressive iteration
|
||||||
|
# only on first autoregressive step q,k,v have same seqlen
|
||||||
|
is_causal = past_key_value is not None
|
||||||
|
|
||||||
|
if self.training and attention_mask.shape[0] == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_q_lens = cu_q_lens.squeeze()
|
||||||
|
|
||||||
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
elif query_states.shape == key_states.shape:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
qkvpacked=True,
|
||||||
|
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||||
|
# the attention_mask should be the same as the key_padding_mask
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
else:
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
if attention_mask is None or attention_mask.all().item():
|
||||||
|
output = flash_attn_kvpacked_func(
|
||||||
|
query_states,
|
||||||
|
torch.stack([key_states, value_states], 2),
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
output_pad_fn,
|
||||||
|
) = generate_qkv(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
kvpacked=True,
|
||||||
|
key_padding_mask=attention_mask,
|
||||||
|
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||||
|
if attention_mask is not None
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
0.0,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=is_causal,
|
||||||
|
)
|
||||||
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
attn_output = output
|
||||||
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||||
|
|
||||||
|
#
|
||||||
|
# flash-attn v2 end
|
||||||
|
#
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(
|
||||||
|
self.hidden_size // self.pretraining_tp, dim=1
|
||||||
|
)
|
||||||
|
attn_output = sum(
|
||||||
|
F.linear(attn_output[i], o_proj_slices[i])
|
||||||
|
for i in range(self.pretraining_tp)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
|
||||||
|
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||||
|
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||||
|
|
||||||
|
if kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
kv_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
kv,
|
||||||
|
output_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
)
|
)
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
|
|
||||||
|
|||||||
140
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
Normal file
140
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""
|
||||||
|
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import transformers.models.llama.modeling_llama
|
||||||
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
|
|
||||||
|
def hijack_llama_sdp_attention():
|
||||||
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||||
|
sdp_attention_forward
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sdp_attention_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if not hasattr(self, "pretraining_tp"):
|
||||||
|
self.pretraining_tp = 1
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
key_value_slicing = (
|
||||||
|
self.num_key_value_heads * self.head_dim
|
||||||
|
) // self.pretraining_tp
|
||||||
|
query_slices = self.q_proj.weight.split(
|
||||||
|
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||||
|
)
|
||||||
|
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||||
|
|
||||||
|
query_states = [
|
||||||
|
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
query_states = torch.cat(query_states, dim=-1)
|
||||||
|
|
||||||
|
key_states = [
|
||||||
|
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
key_states = torch.cat(key_states, dim=-1)
|
||||||
|
|
||||||
|
value_states = [
|
||||||
|
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||||
|
]
|
||||||
|
value_states = torch.cat(value_states, dim=-1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
# [bsz, q_len, nh, hd]
|
||||||
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
# [bsz, nh, t, hd]
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# reuse k, v, self_attention
|
||||||
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
|
|
||||||
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
warnings.warn(
|
||||||
|
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
#
|
||||||
|
# sdp-attn start
|
||||||
|
#
|
||||||
|
|
||||||
|
with torch.backends.cuda.sdp_kernel():
|
||||||
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
is_causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
#
|
||||||
|
# sdp-attn end
|
||||||
|
#
|
||||||
|
|
||||||
|
if self.pretraining_tp > 1:
|
||||||
|
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||||
|
o_proj_slices = self.o_proj.weight.split(
|
||||||
|
self.hidden_size // self.pretraining_tp, dim=1
|
||||||
|
)
|
||||||
|
attn_output = sum(
|
||||||
|
F.linear(attn_output[i], o_proj_slices[i])
|
||||||
|
for i in range(self.pretraining_tp)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output, None, past_key_value
|
||||||
@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
from torch import nn
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import xformers.ops
|
import xformers.ops
|
||||||
@@ -21,12 +21,6 @@ def hijack_llama_attention():
|
|||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_sdp_attention():
|
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
|
||||||
sdp_attention_forward
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def xformers_forward(
|
def xformers_forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -81,15 +75,15 @@ def xformers_forward(
|
|||||||
value_states = value_states.view(
|
value_states = value_states.view(
|
||||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
|
# [bsz, q_len, nh, hd]
|
||||||
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
kv_seq_len = key_states.shape[-2]
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
# [bsz, nh, t, hd]
|
# [bsz, nh, t, hd]
|
||||||
@@ -102,74 +96,50 @@ def xformers_forward(
|
|||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
# repeat k/v heads if n_kv_heads < n_heads
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
key_states = transformers.models.llama.modeling_llama.repeat_kv(
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
key_states, self.num_key_value_groups
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
)
|
|
||||||
value_states = transformers.models.llama.modeling_llama.repeat_kv(
|
|
||||||
value_states, self.num_key_value_groups
|
|
||||||
)
|
|
||||||
|
|
||||||
# We only apply xformers optimizations if we don't need to output the whole attention matrix
|
if output_attentions:
|
||||||
if not output_attentions:
|
warnings.warn(
|
||||||
query_states = query_states.transpose(1, 2)
|
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||||
key_states = key_states.transpose(1, 2)
|
)
|
||||||
value_states = value_states.transpose(1, 2)
|
|
||||||
|
|
||||||
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
#
|
||||||
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
# xformers-attn start
|
||||||
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
#
|
||||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
|
||||||
attn_output = xformers.ops.memory_efficient_attention(
|
query_states = query_states.transpose(1, 2)
|
||||||
query_states, key_states, value_states, attn_bias=None
|
key_states = key_states.transpose(1, 2)
|
||||||
)
|
value_states = value_states.transpose(1, 2)
|
||||||
else:
|
|
||||||
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
|
||||||
attn_output = xformers.ops.memory_efficient_attention(
|
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
|
||||||
query_states,
|
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
|
||||||
key_states,
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
value_states,
|
attn_output = xformers.ops.memory_efficient_attention(
|
||||||
# attn_bias=attention_mask,
|
query_states, key_states, value_states, attn_bias=None
|
||||||
attn_bias=xformers.ops.LowerTriangularMask(),
|
)
|
||||||
)
|
|
||||||
attn_weights = None
|
|
||||||
else:
|
else:
|
||||||
attn_weights = torch.matmul(
|
# input and output should be of form (bsz, q_len, num_heads, head_dim)
|
||||||
query_states, key_states.transpose(2, 3)
|
attn_output = xformers.ops.memory_efficient_attention(
|
||||||
) / math.sqrt(self.head_dim)
|
query_states,
|
||||||
|
key_states,
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
value_states,
|
||||||
raise ValueError(
|
# attn_bias=attention_mask,
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
attn_bias=xformers.ops.LowerTriangularMask(),
|
||||||
f" {attn_weights.size()}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = torch.max(
|
|
||||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
|
||||||
)
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(
|
|
||||||
attn_weights, dim=-1, dtype=torch.float32
|
|
||||||
).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
||||||
# end x-formers vs. not x-formers if-else block
|
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||||
|
|
||||||
|
#
|
||||||
|
# xformers-attn end
|
||||||
|
#
|
||||||
|
|
||||||
if self.pretraining_tp > 1:
|
if self.pretraining_tp > 1:
|
||||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||||
o_proj_slices = self.o_proj.weight.split(
|
o_proj_slices = self.o_proj.weight.split(
|
||||||
@@ -182,103 +152,4 @@ def xformers_forward(
|
|||||||
else:
|
else:
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
def sdp_attention_forward(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
use_cache: bool = False,
|
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
query_states = (
|
|
||||||
self.q_proj(hidden_states)
|
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
)
|
|
||||||
key_states = (
|
|
||||||
self.k_proj(hidden_states)
|
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
)
|
|
||||||
value_states = (
|
|
||||||
self.v_proj(hidden_states)
|
|
||||||
.view(bsz, q_len, self.num_heads, self.head_dim)
|
|
||||||
.transpose(1, 2)
|
|
||||||
)
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
|
|
||||||
query_states, key_states, cos, sin, position_ids
|
|
||||||
)
|
|
||||||
# [bsz, nh, t, hd]
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
|
||||||
|
|
||||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
|
||||||
if not output_attentions:
|
|
||||||
with torch.backends.cuda.sdp_kernel():
|
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
value_states,
|
|
||||||
attn_mask=attention_mask,
|
|
||||||
is_causal=False,
|
|
||||||
)
|
|
||||||
attn_weights = None
|
|
||||||
else:
|
|
||||||
attn_weights = torch.matmul(
|
|
||||||
query_states, key_states.transpose(2, 3)
|
|
||||||
) / math.sqrt(self.head_dim)
|
|
||||||
|
|
||||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
|
|
||||||
f" {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
|
||||||
)
|
|
||||||
attn_weights = attn_weights + attention_mask
|
|
||||||
attn_weights = torch.max(
|
|
||||||
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
|
|
||||||
)
|
|
||||||
|
|
||||||
# upcast attention to fp32
|
|
||||||
attn_weights = nn.functional.softmax(
|
|
||||||
attn_weights, dim=-1, dtype=torch.float32
|
|
||||||
).to(query_states.dtype)
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
|
||||||
|
|
||||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
|
||||||
raise ValueError(
|
|
||||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
|
||||||
f" {attn_output.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
|
||||||
|
|
||||||
attn_output = self.o_proj(attn_output)
|
|
||||||
|
|
||||||
return attn_output, attn_weights, past_key_value
|
|
||||||
|
|||||||
@@ -312,9 +312,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
if len(source) < 2:
|
if len(source) < 2:
|
||||||
# If there isn't a back and forth conversation, ignore it
|
# If there isn't a back and forth conversation, ignore it
|
||||||
# also happens on the data splitting leaving empty conversations
|
# also happens on the data splitting leaving empty conversations
|
||||||
raise IndexError(
|
raise IndexError
|
||||||
f"A conversation entry has less than 2 messages :\n{source}"
|
|
||||||
)
|
|
||||||
|
|
||||||
conv = self._conversation.copy()
|
conv = self._conversation.copy()
|
||||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||||
|
|||||||
@@ -28,9 +28,6 @@ def gpu_memory_usage_smi(device=0):
|
|||||||
|
|
||||||
|
|
||||||
def log_gpu_memory_usage(log, msg, device):
|
def log_gpu_memory_usage(log, msg, device):
|
||||||
if not torch.cuda.is_available():
|
|
||||||
return (0, 0, 0)
|
|
||||||
|
|
||||||
usage, cache, misc = gpu_memory_usage_all(device)
|
usage, cache, misc = gpu_memory_usage_all(device)
|
||||||
extras = []
|
extras = []
|
||||||
if cache > 0:
|
if cache > 0:
|
||||||
|
|||||||
@@ -32,45 +32,6 @@ if TYPE_CHECKING:
|
|||||||
from axolotl.utils.dict import DictDefault # noqa: F401
|
from axolotl.utils.dict import DictDefault # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def smart_tokenizer_and_embedding_resize(
|
|
||||||
tokenizer: transformers.PreTrainedTokenizer,
|
|
||||||
model: transformers.PreTrainedModel,
|
|
||||||
resize_token_embeddings_multiple: Optional[int] = None,
|
|
||||||
):
|
|
||||||
"""Resize tokenizer and embedding.
|
|
||||||
|
|
||||||
Note: This function resizes the tokenizer to accommodate additional special tokens and the
|
|
||||||
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
|
|
||||||
have been added, the function computes the average embedding values of the existing embeddings
|
|
||||||
and sets those values for the new special token embeddings. This is done separately for the input
|
|
||||||
embeddings and output embeddings of the model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
old_tokens = model.get_input_embeddings().weight.data.shape[0]
|
|
||||||
num_new_tokens = len(tokenizer) - old_tokens
|
|
||||||
embeddings_len = (
|
|
||||||
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
|
|
||||||
* resize_token_embeddings_multiple
|
|
||||||
if resize_token_embeddings_multiple
|
|
||||||
else len(tokenizer)
|
|
||||||
)
|
|
||||||
model.resize_token_embeddings(embeddings_len)
|
|
||||||
|
|
||||||
if num_new_tokens > 0:
|
|
||||||
input_embeddings = model.get_input_embeddings().weight.data
|
|
||||||
output_embeddings = model.get_output_embeddings().weight.data
|
|
||||||
|
|
||||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
|
|
||||||
dim=0, keepdim=True
|
|
||||||
)
|
|
||||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
|
||||||
dim=0, keepdim=True
|
|
||||||
)
|
|
||||||
|
|
||||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
|
||||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
|
||||||
|
|
||||||
|
|
||||||
def load_tokenizer(cfg):
|
def load_tokenizer(cfg):
|
||||||
tokenizer_kwargs = {}
|
tokenizer_kwargs = {}
|
||||||
use_fast = True # this is the default
|
use_fast = True # this is the default
|
||||||
@@ -151,9 +112,7 @@ def load_model(
|
|||||||
LOG.info("patching with xformers attention")
|
LOG.info("patching with xformers attention")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
elif cfg.is_llama_derived_model and cfg.sdp_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
|
||||||
hijack_llama_sdp_attention,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching with sdp attention")
|
LOG.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
@@ -268,12 +227,8 @@ def load_model(
|
|||||||
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
|
||||||
from transformers import LlamaForCausalLM
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
config_kwargs = {}
|
|
||||||
if cfg.rope_scaling:
|
|
||||||
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
|
||||||
config = LlamaConfig.from_pretrained(
|
config = LlamaConfig.from_pretrained(
|
||||||
base_model_config,
|
base_model_config, rope_scaling=cfg.rope_scaling
|
||||||
**config_kwargs,
|
|
||||||
)
|
)
|
||||||
model = LlamaForCausalLM.from_pretrained(
|
model = LlamaForCausalLM.from_pretrained(
|
||||||
base_model,
|
base_model,
|
||||||
@@ -366,16 +321,17 @@ def load_model(
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
smart_tokenizer_and_embedding_resize(
|
embeddings_len = (
|
||||||
tokenizer,
|
math.ceil(len(tokenizer) / 32) * 32
|
||||||
model,
|
if cfg.resize_token_embeddings_to_32x
|
||||||
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
|
else len(tokenizer)
|
||||||
)
|
)
|
||||||
|
model.resize_token_embeddings(embeddings_len)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model.config, "max_position_embeddings")
|
||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len > model.config.max_position_embeddings
|
and cfg.sequence_len >= model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
|
||||||
|
|||||||
@@ -440,9 +440,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
training_arguments_kwargs["push_to_hub"] = True
|
training_arguments_kwargs["push_to_hub"] = True
|
||||||
training_arguments_kwargs["hub_private_repo"] = True
|
training_arguments_kwargs["hub_private_repo"] = True
|
||||||
|
|
||||||
if cfg.hub_strategy:
|
|
||||||
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
|
|
||||||
|
|
||||||
if cfg.save_safetensors:
|
if cfg.save_safetensors:
|
||||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||||
|
|
||||||
@@ -451,17 +448,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
"sample_packing_efficiency"
|
"sample_packing_efficiency"
|
||||||
] = cfg.sample_packing_eff_est
|
] = cfg.sample_packing_eff_est
|
||||||
|
|
||||||
if cfg.val_set_size == 0:
|
|
||||||
evaluation_strategy = "no"
|
|
||||||
elif cfg.eval_steps < 1:
|
|
||||||
# eval every epoch
|
|
||||||
evaluation_strategy = "epoch"
|
|
||||||
else:
|
|
||||||
# eval every eval_steps steps
|
|
||||||
evaluation_strategy = "steps"
|
|
||||||
|
|
||||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
|
||||||
max_seq_length=cfg.sequence_len,
|
max_seq_length=cfg.sequence_len,
|
||||||
per_device_train_batch_size=cfg.micro_batch_size,
|
per_device_train_batch_size=cfg.micro_batch_size,
|
||||||
per_device_eval_batch_size=cfg.eval_batch_size
|
per_device_eval_batch_size=cfg.eval_batch_size
|
||||||
@@ -471,7 +459,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||||
num_train_epochs=cfg.num_epochs,
|
num_train_epochs=cfg.num_epochs,
|
||||||
learning_rate=cfg.learning_rate,
|
learning_rate=cfg.learning_rate,
|
||||||
evaluation_strategy=evaluation_strategy,
|
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
|
||||||
save_strategy="steps" if cfg.save_steps else "epoch",
|
save_strategy="steps" if cfg.save_steps else "epoch",
|
||||||
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
|
||||||
save_steps=cfg.save_steps,
|
save_steps=cfg.save_steps,
|
||||||
|
|||||||
Reference in New Issue
Block a user