btlm and falcon monkey patches for flash attn (#566)
This commit is contained in:
90
examples/cerebras/btlm-ft.yml
Normal file
90
examples/cerebras/btlm-ft.yml
Normal file
@@ -0,0 +1,90 @@
|
||||
base_model: cerebras/btlm-3b-8k-base
|
||||
base_model_config: cerebras/btlm-3b-8k-base
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: GPT2Tokenizer
|
||||
trust_remote_code: true
|
||||
tokenizer_use_fast: true
|
||||
tokenizer_legacy: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
push_dataset_to_hub:
|
||||
hf_use_auth_token: true
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.01
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
sample_packing: false
|
||||
sample_packing_eff_est:
|
||||
sample_packing_seq_len_multiplier:
|
||||
total_num_tokens:
|
||||
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_modules:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
output_dir: btlm-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.000000001
|
||||
max_grad_norm: 1.0
|
||||
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
lr_quadratic_warmup: true
|
||||
learning_rate: 0.000085
|
||||
train_on_inputs: true
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
sdp_attention:
|
||||
flash_optimum:
|
||||
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
|
||||
warmup_steps: 32
|
||||
eval_steps:
|
||||
save_steps:
|
||||
save_total_limit:
|
||||
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
special_tokens:
|
||||
pad_token: "<|endoftext|>"
|
||||
fsdp:
|
||||
# - full_shard
|
||||
# - auto_wrap
|
||||
fsdp_config:
|
||||
# fsdp_state_dict_type: FULL_STATE_DICT
|
||||
# fsdp_transformer_layer_cls_to_wrap: BTLMBlock
|
||||
64
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
Normal file
64
src/axolotl/monkeypatch/btlm_attn_hijack_flash.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Flash attention monkey patch for cerebras btlm model
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
||||
# this is a wonky hack to get the remotely loaded module
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_btlm to be available
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_btlm", ".modeling_btlm"
|
||||
)
|
||||
modeling_btlm = importlib.import_module(module_name)
|
||||
modeling_btlm.BTLMAttention._attn = ( # pylint: disable=protected-access
|
||||
flashattn_attn
|
||||
)
|
||||
|
||||
|
||||
def flashattn_attn(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: Optional[torch.Tensor] = None,
|
||||
value: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
position_bias: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
softmax_scale = (
|
||||
1 / (key.size(-1) ** self.attn_scale_power) if self.scale_attn_weights else None
|
||||
)
|
||||
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
|
||||
# Perform Flash attention
|
||||
attn_output = flash_attn_func(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p=0.0, # Assuming you have this attribute
|
||||
softmax_scale=softmax_scale, # Set this if you have specific scaling in mind
|
||||
causal=not self.is_cross_attention, # Assuming you have this attribute
|
||||
return_attn_probs=False, # Set this based on your needs
|
||||
)
|
||||
|
||||
# Optional: Apply head mask if it's not None
|
||||
if head_mask is not None:
|
||||
attn_output *= head_mask
|
||||
|
||||
attn_output = attn_output.permute(0, 2, 1, 3)
|
||||
|
||||
return attn_output, None # We don't have explicit attn_weights in Flash attention
|
||||
101
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
Normal file
101
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Flash Attention monkey patch for Falcon
|
||||
|
||||
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor, # pylint: disable=unused-argument
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False, # pylint: disable=unused-argument
|
||||
):
|
||||
fused_qkv = self.query_key_value(
|
||||
hidden_states
|
||||
) # [batch_size, seq_length, 3 x hidden_size]
|
||||
num_kv_heads = (
|
||||
self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||
)
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
) = self._split_heads( # pylint: disable=protected-access
|
||||
fused_qkv
|
||||
)
|
||||
|
||||
batch_size, query_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(
|
||||
batch_size * self.num_heads, query_length, self.head_dim
|
||||
)
|
||||
key_layer = key_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv_heads,
|
||||
query_length,
|
||||
self.head_dim,
|
||||
)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv_heads, query_length, self.head_dim
|
||||
)
|
||||
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
# unused
|
||||
# _, kv_length, _ = key_layer.shape
|
||||
if use_cache:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
# unused
|
||||
# attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
||||
query_layer_ = (
|
||||
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
key_layer_ = (
|
||||
key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
value_layer_ = (
|
||||
value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
|
||||
if alibi is not None:
|
||||
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
|
||||
|
||||
# below output will have shape (batch_size, seqlen, nheads, headdim)
|
||||
attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size, query_length, self.num_heads * self.head_dim
|
||||
)
|
||||
output_tensor = self.dense(attn_output)
|
||||
return output_tensor, present
|
||||
|
||||
|
||||
def replace_falcon_attn_with_flash_attn():
|
||||
transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
|
||||
@@ -100,10 +100,31 @@ def load_model(
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
load_in_8bit = cfg.load_in_8bit
|
||||
|
||||
if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.btlm_attn_hijack_flash import (
|
||||
replace_btlm_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
||||
|
||||
if hasattr(model_config, "model_type") and model_config.model_type in [
|
||||
"falcon",
|
||||
"RefinedWebModel",
|
||||
"RefinedWeb",
|
||||
]:
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
|
||||
replace_falcon_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
replace_falcon_attn_with_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
@@ -338,6 +359,9 @@ def load_model(
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
module.to(torch.float32)
|
||||
if model_config.model_type == "btlm":
|
||||
# don't upcast lm_head for btlm
|
||||
continue
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user