Compare commits

..

10 Commits

Author SHA1 Message Date
Wing Lian
31079cd5fd smart resize embeddings
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-14 23:44:15 -04:00
NanoCode012
41ecb451c2 Feat(doc): Add max_steps to readme (#389) 2023-08-15 00:34:22 +09:00
Gabriel Puliatti
3c2ad00d07 Feat(config): add max steps (#387) 2023-08-14 11:19:29 -04:00
florian peyron
5d48a10548 Added "epoch" evaluation_strategy (#388) 2023-08-14 10:59:23 -04:00
NanoCode012
73a0b6ead5 Feat(config): Add hub_strategy (#386) 2023-08-14 07:12:55 -04:00
florian peyron
63fdb5a7fb Error msg for sharegpt if conv has less than 2 msg (#379) 2023-08-14 17:40:40 +09:00
mhenrichsen
fdffef5940 new llama-2 default settings (#370)
* new default settings

* fix whitespace

* rm max packed sequence length

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
2023-08-14 17:39:09 +09:00
Wing Lian
919246fbc1 don't pass rope_scaling kwarg if it's None (#383) 2023-08-13 18:57:38 -04:00
Wing Lian
ffac902c1b bump flash-attn to 2.0.4 for the base docker image (#382) 2023-08-13 17:55:04 -04:00
Charles Goddard
15f6e57eaa Fix crash when running without CUDA 2023-08-13 13:36:40 -07:00
12 changed files with 363 additions and 502 deletions

View File

@@ -326,9 +326,9 @@ tokenizer_type: AutoTokenizer
trust_remote_code: trust_remote_code:
# use_fast option for tokenizer loading from_pretrained, default to True # use_fast option for tokenizer loading from_pretrained, default to True
tokenizer_use_fast: tokenizer_use_fast:
# resize the model embeddings when new tokens are added to multiples of 32 # resize the model embeddings when new tokens are added to multiples of N
# this is reported to improve training speed on some models # multiples of 32 are reported to improve training speed on some models
resize_token_embeddings_to_32x: resize_token_embeddings_multiple:
# whether you are training a 4-bit GPTQ quantized model # whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
@@ -364,6 +364,9 @@ dataset_prepared_path: data/last_run_prepared
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# push checkpoints to hub # push checkpoints to hub
hub_model_id: # repo path to push finetuned model hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub` # required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
@@ -432,7 +435,8 @@ learning_rate: 0.00003
logging_steps: logging_steps:
save_steps: save_steps:
eval_steps: eval_steps:
save_total_limit: save_total_limit: # checkpoints saved at a time
max_steps:
# save model as safetensors (require safetensors package) # save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:

View File

@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \ cd flash-attention && \
git checkout v2.0.1 && \ git checkout v2.0.4 && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \ cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \

View File

@@ -15,7 +15,7 @@ val_set_size: 0.01
output_dir: ./lora-out output_dir: ./lora-out
sequence_len: 4096 sequence_len: 4096
max_packed_sequence_len: 4096 sample_packing: true
adapter: lora adapter: lora
lora_model_dir: lora_model_dir:
@@ -49,8 +49,8 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: true xformers_attention:
flash_attention: flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 20 eval_steps: 20
@@ -64,4 +64,3 @@ special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -18,7 +18,8 @@ adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 4096 sequence_len: 4096
max_packed_sequence_len: 4096 sample_packing: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
@@ -50,8 +51,8 @@ early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
logging_steps: 1 logging_steps: 1
xformers_attention: true xformers_attention:
flash_attention: flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 20 eval_steps: 20
@@ -65,4 +66,3 @@ special_tokens:
bos_token: "<s>" bos_token: "<s>"
eos_token: "</s>" eos_token: "</s>"
unk_token: "<unk>" unk_token: "<unk>"
pad_token: "<pad>"

View File

@@ -209,7 +209,13 @@ def train(
cfg, train_dataset, eval_dataset cfg, train_dataset, eval_dataset
) )
barrier() barrier()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")

View File

@@ -2,54 +2,26 @@
# copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py # copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
import warnings
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F
import transformers import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
try: try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
flash_attn_kvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_varlen_qkvpacked_func,
)
except ImportError: except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
)
from flash_attn.flash_attn_interface import ( from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func, flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
) )
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def replace_llama_attn_with_flash_attn(): from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
# Disable the transformation of the attention mask in LlamaModel as the flash attention def forward(
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def flashattn_forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
@@ -65,294 +37,124 @@ def flashattn_forward(
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size() bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"): query_states = (
self.pretraining_tp = 1 self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
if self.pretraining_tp > 1: .transpose(1, 2)
key_value_slicing = ( )
self.num_key_value_heads * self.head_dim key_states = (
) // self.pretraining_tp self.k_proj(hidden_states)
query_slices = self.q_proj.weight.split( .view(bsz, q_len, self.num_heads, self.head_dim)
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0 .transpose(1, 2)
) )
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) value_states = (
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
query_states = [ .transpose(1, 2)
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp) )
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd] # [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd] # [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: assert past_key_value is None, "past_key_value is not supported"
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb( query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
# [bsz, nh, t, hd] # [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
if past_key_value is not None: # Flash attention codes from
# reuse k, v, self_attention # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None # transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
# repeat k/v heads if n_kv_heads < n_heads if key_padding_mask is None:
key_states = repeat_kv(key_states, self.num_key_value_groups) qkv = rearrange(qkv, "b s ... -> (b s) ...")
value_states = repeat_kv(value_states, self.num_key_value_groups) max_s = q_len
cu_q_lens = torch.arange(
if output_attentions: 0,
warnings.warn( (bsz + 1) * q_len,
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." step=q_len,
dtype=torch.int32,
device=qkv.device,
) )
output = flash_attn_varlen_qkvpacked_func(
# qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
# flash-attn v2 start )
# output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif attention_mask.shape[0] == 1:
if self.training:
# during training q,k,v always have same seqlen
assert key_states.shape == query_states.shape
is_causal = True
else:
# turn off FA causal mask after first inference autoregressive iteration
# only on first autoregressive step q,k,v have same seqlen
is_causal = past_key_value is not None
if self.training and attention_mask.shape[0] == 1:
# special handling using sample packing # special handling using sample packing
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids) cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze() cu_q_lens = cu_q_lens.squeeze()
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=is_causal qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: else:
query_states = query_states.transpose(1, 2) nheads = qkv.shape[-2]
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2) # pylint: disable=invalid-name
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv( x = rearrange(qkv, "b s three h d -> b s (three h d)")
query_states, x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
key_states, x_unpad = rearrange(
value_states, x_unpad,
qkvpacked=True, "nnz (three h d) -> nnz three h d",
# We have disabled _prepare_decoder_attention_mask in LlamaModel three=3,
# the attention_mask should be the same as the key_padding_mask h=nheads,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
) )
output_unpad = flash_attn_varlen_qkvpacked_func( output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad, x_unpad,
cu_seqlens_q, cu_q_lens,
max_seqlen_q, max_s,
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=True,
) )
output = output_pad_fn(output_unpad) output = rearrange(
else: pad_input(
query_states = query_states.transpose(1, 2) rearrange(output_unpad, "nnz h d -> nnz (h d)"),
key_states = key_states.transpose(1, 2) indices,
value_states = value_states.transpose(1, 2) bsz,
if attention_mask is None or attention_mask.all().item(): q_len,
output = flash_attn_kvpacked_func( ),
query_states, "b s (h d) -> b s h d",
torch.stack([key_states, value_states], 2), h=nheads,
causal=is_causal,
)
else:
( # pylint: disable=unbalanced-tuple-unpacking
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
_,
_,
output_pad_fn,
) = generate_qkv(
query_states,
key_states,
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
)
output_unpad = flash_attn_varlen_kvpacked_func(
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
0.0,
softmax_scale=None,
causal=is_causal,
)
output = output_pad_fn(output_unpad)
attn_output = output
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
#
# flash-attn v2 end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
def generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
"""
Arguments:
q: (batch_size, seqlen_q, nheads, d)
k: (batch_size, seqlen_k, nheads_k, d)
v: (batch_size, seqlen_k, nheads_k, d)
query_padding_mask: (batch_size, seqlen), bool
key_padding_mask: (batch_size, seqlen), bool
"""
assert not (kvpacked and qkvpacked)
batch_size, seqlen_q, nheads, d = q.shape
_, seqlen_k, nheads_k, _ = k.shape
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
if query_padding_mask is not None:
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
q, query_padding_mask
)
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
output_unpad, indices_q, batch_size, seqlen_q
)
else:
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q_unpad.device,
)
max_seqlen_q = seqlen_q
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
output_unpad, "(b s) h d -> b s h d", b=batch_size
)
if key_padding_mask is not None:
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
else:
k_unpad = rearrange(k, "b s h d -> (b s) h d")
v_unpad = rearrange(v, "b s h d -> (b s) h d")
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=k_unpad.device,
)
max_seqlen_k = seqlen_k
if qkvpacked:
assert nheads == nheads_k
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
qkv = torch.stack([q, k, v], dim=2)
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
if kvpacked:
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
kv = torch.stack([k, v], dim=2)
return (
q_unpad,
kv_unpad,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
kv,
output_pad_fn,
) )
return ( return (
q_unpad, self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
k_unpad, None,
v_unpad, None,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
q,
k,
v,
output_pad_fn,
) )
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self,
attention_mask,
input_shape,
inputs_embeds,
past_key_values_length,
): # pylint: disable=unused-argument
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward

View File

@@ -1,140 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -3,13 +3,13 @@ Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-g
""" """
import logging import logging
import warnings import math
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import transformers.models.llama.modeling_llama import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from torch import nn
try: try:
import xformers.ops import xformers.ops
@@ -21,6 +21,12 @@ def hijack_llama_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def xformers_forward( def xformers_forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -75,15 +81,15 @@ def xformers_forward(
value_states = value_states.view( value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2) ).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2] kv_seq_len = key_states.shape[-2]
if past_key_value is not None: if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2] kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb( (
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
# [bsz, nh, t, hd] # [bsz, nh, t, hd]
@@ -96,50 +102,74 @@ def xformers_forward(
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads # repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups) key_states = transformers.models.llama.modeling_llama.repeat_kv(
value_states = repeat_kv(value_states, self.num_key_value_groups) key_states, self.num_key_value_groups
)
value_states = transformers.models.llama.modeling_llama.repeat_kv(
value_states, self.num_key_value_groups
)
if output_attentions: # We only apply xformers optimizations if we don't need to output the whole attention matrix
warnings.warn( if not output_attentions:
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." query_states = query_states.transpose(1, 2)
) key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# xformers-attn start # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
# if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
query_states = query_states.transpose(1, 2) attn_output = xformers.ops.memory_efficient_attention(
key_states = key_states.transpose(1, 2) query_states, key_states, value_states, attn_bias=None
value_states = value_states.transpose(1, 2) )
else:
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. # input and output should be of form (bsz, q_len, num_heads, head_dim)
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. attn_output = xformers.ops.memory_efficient_attention(
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: query_states,
# input and output should be of form (bsz, q_len, num_heads, head_dim) key_states,
attn_output = xformers.ops.memory_efficient_attention( value_states,
query_states, key_states, value_states, attn_bias=None # attn_bias=attention_mask,
) attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
else: else:
# input and output should be of form (bsz, q_len, num_heads, head_dim) attn_weights = torch.matmul(
attn_output = xformers.ops.memory_efficient_attention( query_states, key_states.transpose(2, 3)
query_states, ) / math.sqrt(self.head_dim)
key_states,
value_states, if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
# attn_bias=attention_mask, raise ValueError(
attn_bias=xformers.ops.LowerTriangularMask(), f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
) f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
# end x-formers vs. not x-formers if-else block
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# xformers-attn end
#
if self.pretraining_tp > 1: if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2) attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split( o_proj_slices = self.o_proj.weight.split(
@@ -152,4 +182,103 @@ def xformers_forward(
else: else:
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value return attn_output, attn_weights, past_key_value
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions:
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
attn_weights = None
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value

View File

@@ -312,7 +312,9 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations
raise IndexError raise IndexError(
f"A conversation entry has less than 2 messages :\n{source}"
)
conv = self._conversation.copy() conv = self._conversation.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]} roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

View File

@@ -28,6 +28,9 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device): def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available():
return (0, 0, 0)
usage, cache, misc = gpu_memory_usage_all(device) usage, cache, misc = gpu_memory_usage_all(device)
extras = [] extras = []
if cache > 0: if cache > 0:

View File

@@ -32,6 +32,45 @@ if TYPE_CHECKING:
from axolotl.utils.dict import DictDefault # noqa: F401 from axolotl.utils.dict import DictDefault # noqa: F401
def smart_tokenizer_and_embedding_resize(
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
resize_token_embeddings_multiple: Optional[int] = None,
):
"""Resize tokenizer and embedding.
Note: This function resizes the tokenizer to accommodate additional special tokens and the
embedding matrix of the model to match the new size of the tokenizer. If any new special tokens
have been added, the function computes the average embedding values of the existing embeddings
and sets those values for the new special token embeddings. This is done separately for the input
embeddings and output embeddings of the model.
"""
old_tokens = model.get_input_embeddings().weight.data.shape[0]
num_new_tokens = len(tokenizer) - old_tokens
embeddings_len = (
math.ceil(len(tokenizer) / resize_token_embeddings_multiple)
* resize_token_embeddings_multiple
if resize_token_embeddings_multiple
else len(tokenizer)
)
model.resize_token_embeddings(embeddings_len)
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def load_tokenizer(cfg): def load_tokenizer(cfg):
tokenizer_kwargs = {} tokenizer_kwargs = {}
use_fast = True # this is the default use_fast = True # this is the default
@@ -112,7 +151,9 @@ def load_model(
LOG.info("patching with xformers attention") LOG.info("patching with xformers attention")
hijack_llama_attention() hijack_llama_attention()
elif cfg.is_llama_derived_model and cfg.sdp_attention: elif cfg.is_llama_derived_model and cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_sdp_attention,
)
LOG.info("patching with sdp attention") LOG.info("patching with sdp attention")
hijack_llama_sdp_attention() hijack_llama_sdp_attention()
@@ -227,8 +268,12 @@ def load_model(
elif cfg.is_llama_derived_model and not cfg.trust_remote_code: elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config_kwargs = {}
if cfg.rope_scaling:
config_kwargs["rope_scaling"] = cfg.rope_scaling
config = LlamaConfig.from_pretrained( config = LlamaConfig.from_pretrained(
base_model_config, rope_scaling=cfg.rope_scaling base_model_config,
**config_kwargs,
) )
model = LlamaForCausalLM.from_pretrained( model = LlamaForCausalLM.from_pretrained(
base_model, base_model,
@@ -321,17 +366,16 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
embeddings_len = ( smart_tokenizer_and_embedding_resize(
math.ceil(len(tokenizer) / 32) * 32 tokenizer,
if cfg.resize_token_embeddings_to_32x model,
else len(tokenizer) resize_token_embeddings_multiple=cfg.resize_token_embeddings_multiple,
) )
model.resize_token_embeddings(embeddings_len)
if ( if (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings and model.config.max_position_embeddings
and cfg.sequence_len >= model.config.max_position_embeddings and cfg.sequence_len > model.config.max_position_embeddings
): ):
LOG.warning( LOG.warning(
f"increasing model.config.max_position_embeddings to {cfg.sequence_len}" f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"

View File

@@ -440,6 +440,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
training_arguments_kwargs["push_to_hub"] = True training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors: if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
@@ -448,8 +451,17 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
"sample_packing_efficiency" "sample_packing_efficiency"
] = cfg.sample_packing_eff_est ] = cfg.sample_packing_eff_est
if cfg.val_set_size == 0:
evaluation_strategy = "no"
elif cfg.eval_steps < 1:
# eval every epoch
evaluation_strategy = "epoch"
else:
# eval every eval_steps steps
evaluation_strategy = "steps"
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len, max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size, per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size per_device_eval_batch_size=cfg.eval_batch_size
@@ -459,7 +471,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
eval_accumulation_steps=cfg.gradient_accumulation_steps, eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs, num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate, learning_rate=cfg.learning_rate,
evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", evaluation_strategy=evaluation_strategy,
save_strategy="steps" if cfg.save_steps else "epoch", save_strategy="steps" if cfg.save_steps else "epoch",
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None, eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps, save_steps=cfg.save_steps,