Compare commits

..

4 Commits

Author SHA1 Message Date
Aman Karmani
956a177678 speed up flash-attn inference
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-13 18:03:38 +00:00
Aman Karmani
747e84d3bb update flash-attn patch for 70B/GQA and inference using helper from flash-attn tests 2023-08-13 15:41:44 +00:00
Aman Karmani
c45a786039 sync xformers patch to follow shared format and be diffable 2023-08-13 15:41:06 +00:00
Aman Karmani
70e6c28121 split sdp attn into its own patch 2023-08-13 15:40:43 +00:00
12 changed files with 533 additions and 394 deletions

View File

@@ -326,9 +326,9 @@ tokenizer_type: AutoTokenizer
trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast:
# resize the model embeddings when new tokens are added to multiples of N
# multiples of 32 are reported to improve training speed on some models
resize_token_embeddings_multiple:
# resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# whether you are training a 4-bit GPTQ quantized model
gptq: true
@@ -364,9 +364,6 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path
# push checkpoints to hub
hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean
@@ -435,8 +432,7 @@ learning_rate: 0.00003
logging_steps:
save_steps:
eval_steps:
save_total_limit: # checkpoints saved at a time
max_steps:
save_total_limit:
# save model as safetensors (require safetensors package)
save_safetensors:

View File

@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \
git checkout v2.0.4 && \
git checkout v2.0.1 && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \

View File

@@ -15,7 +15,7 @@ val_set_size: 0.01
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
max_packed_sequence_len: 4096
adapter: lora
lora_model_dir:
@@ -49,8 +49,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
xformers_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 20
@@ -64,3 +64,4 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -18,8 +18,7 @@ adapter: qlora
lora_model_dir:
sequence_len: 4096
sample_packing: true
max_packed_sequence_len: 4096
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
@@ -51,8 +50,8 @@ early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
xformers_attention: true
flash_attention:
warmup_steps: 10
eval_steps: 20
@@ -66,3 +65,4 @@ special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -209,13 +209,7 @@ def train(
cfg, train_dataset, eval_dataset
)
barrier()
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")

View File

@@ -2,142 +2,38 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
)
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(
0,
(bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device,
)
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif attention_mask.shape[0] == 1:
# special handling using sample packing
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze()
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
)
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,
None,
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
# Disable the transformation of the attention mask in LlamaModel as the flash attention
@@ -153,8 +49,310 @@ def _prepare_decoder_attention_mask(
return attention_mask
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
def flashattn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# flash-attn v2 start
#
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = past_key_value is not None
if self.training and attention_mask.shape[0] == 1:
# special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze()
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
query_states,
key_states,
value_states,
qkvpacked=True,
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
cu_seqlens_q,
max_seqlen_q,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if attention_mask is None or attention_mask.all().item():
output = flash_attn_kvpacked_func(
query_states,
torch.stack([key_states, value_states], 2),
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
#
# flash-attn v2 end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
)
return (
q_unpad,
k_unpad,
v_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@@ -0,0 +1,140 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
"""
import logging
import math
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from torch import nn
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
import xformers.ops
@@ -21,12 +21,6 @@ def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def xformers_forward(
self,
hidden_states: torch.Tensor,
@@ -81,15 +75,15 @@ def xformers_forward(
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
@@ -102,74 +96,50 @@ def xformers_forward(
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = transformers.models.llama.modeling_llama.repeat_kv(
key_states, self.num_key_value_groups
)
value_states = transformers.models.llama.modeling_llama.repeat_kv(
value_states, self.num_key_value_groups
)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
#
# xformers-attn start
#
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None
)
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
# end x-formers vs. not x-formers if-else block
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# xformers-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
@@ -182,103 +152,4 @@ def xformers_forward(
else:
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
attn_weights = None
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
return attn_output, None, past_key_value

View File

@@ -312,9 +312,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2:
# If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations
raise IndexError(
f"A conversation entry has less than 2 messages :\n{source}"
)
raise IndexError
conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

View File

@@ -28,9 +28,6 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:

View File

@@ -32,45 +32,6 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401
def smart_tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
resize_token_embeddings_multiple: Optional[int] = None,
):
"""Resize tokenizer and embedding.
Note: This function resizes the tokenizer to accommodate additional special tokens and the
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
have been added, the function computes the average embedding values of the existing embeddings
and sets those values for the new special token embeddings. This is done separately for the input
embeddings and output embeddings of the model.
"""
old_tokens = model.get_input_embeddings().weight.data.shape[0]
num_new_tokens = len(tokenizer) - old_tokens
embeddings_len = (
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
* resize_token_embeddings_multiple
if resize_token_embeddings_multiple
else len(tokenizer)
)
model.resize_token_embeddings(embeddings_len)
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def load_tokenizer(cfg):
tokenizer_kwargs = {}
use_fast = True # this is the default
@@ -151,9 +112,7 @@ def load_model(
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention,
)
from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
LOG.info("patching with sdp attention")
hijack_llama_sdp_attention()
@@ -268,12 +227,8 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained(
base_model_config,
**config_kwargs,
base_model_config, rope_scaling=cfg.rope_scaling
)
model = LlamaForCausalLM.from_pretrained(
base_model,
@@ -366,16 +321,17 @@ def load_model(
**model_kwargs,
)
smart_tokenizer_and_embedding_resize(
tokenizer,
model,
resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
if cfg.resize_token_embeddings_to_32x
else len(tokenizer)
)
model.resize_token_embeddings(embeddings_len)
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings
and cfg.sequence_len > model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings
):
LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"

View File

@@ -440,9 +440,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
@@ -451,17 +448,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0:
evaluation_strategy = "no"
elif cfg.eval_steps < 1:
# eval every epoch
evaluation_strategy = "epoch"
else:
# eval every eval_steps steps
evaluation_strategy = "steps"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
@@ -471,7 +459,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
evaluation_strategy=evaluation_strategy,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,