Add shifted sparse attention (#973) [skip-ci]
* Add s2_attn to hijack flash code
* Refactor code to account for s2_attn
* Add ``s2_attention`` option to llama configs
* Add ``s2_attention`` option to README config
* Format code to appease linter
* chore: lint
* Remove xpos and llama-landmark [bad merge]
* add e2e smoke tests for shifted sparse attention
* remove stray patch from merge
* update yml with link to paper for s2_attention/longlora
* fix assertion check for full fine tune
* increase sequence len for tests and PR feedback updates
* reduce context len to 16k for tests
* reduce context len to 16k for tests
* reduce batch size for larger context len and update test to check message
* fix test for message

---------

Co-authored-by: joecummings <jrcummings@devvm050.nha0.facebook.com>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
@@ -834,7 +834,8 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention:

# Shifted-sparse attention (only llama) - https://arxiv.org/pdf/2309.12307.pdf
s2_attention:

# Resume from a specific checkpoint dir
resume_from_checkpoint:

# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
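Read together with the example configs below, the minimal recipe is to enable flash attention and shifted-sparse attention while leaving sample packing off. A minimal sketch in the DictDefault style the new tests use; the model and sequence length are illustrative placeholders, not values prescribed by the README:

# Illustrative cfg fragment only; the option names come from the README entries above,
# the concrete values are placeholders borrowed from the e2e tests added below.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "JackFram/llama-68m",  # placeholder model
        "sequence_len": 16384,
        "flash_attention": True,  # s2_attention is only wired up via the flash-attn patch
        "s2_attention": True,
        "sample_packing": False,  # packing + s2_attention is rejected with a ValueError
    }
)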
@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4

@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4

@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4

@@ -52,6 +52,7 @@ local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:

warmup_steps: 10
evals_per_epoch: 4

@@ -52,6 +52,7 @@ logging_steps: 1
xformers_attention:
flash_attention: true
gptq_groupsize:
s2_attention:
gptq_model_v1:
warmup_steps: 20
evals_per_epoch: 4
@@ -70,11 +70,20 @@ def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
     rms_norm: Optional[bool] = False,
+    use_shifted_sparse_attn: Optional[bool] = False,
 ):
     transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (  # pylint: disable=protected-access
         _prepare_decoder_attention_mask
     )
-    transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
+    if use_shifted_sparse_attn:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+            flashattn_forward_with_s2attn
+        )
+    else:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = (
+            flashattn_forward
+        )

     if packed:
         transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
         transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -213,6 +222,136 @@ def _prepare_decoder_attention_mask(
    return attention_mask


GROUP_SIZE_RATIO = 1 / 4


def flashattn_forward_with_s2attn(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,  # pylint: disable=unused-argument
    cu_seqlens: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
    max_seqlen: Optional[torch.Tensor] = None,  # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py

    attention_mask: [bsz, q_len]

    `cu_seqlens` will be ignored if provided
    `max_seqlen` will be ignored if provided
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()

    query_states = (
        self.q_proj(hidden_states)
        .view(bsz, q_len, self.num_heads, self.head_dim)
        .transpose(1, 2)
    )
    key_states = (
        self.k_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )
    value_states = (
        self.v_proj(hidden_states)
        .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
        .transpose(1, 2)
    )
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]
    # pylint: disable=duplicate-code

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, position_ids
    )

    # Past Key value support
    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack(
        [query_states, key_states, value_states], dim=2
    )  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]

    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask

    key_padding_mask = attention_mask.repeat(2, 1)
    nheads = qkv.shape[-2]
    # shift

    group_size = int(q_len * GROUP_SIZE_RATIO)
    if q_len % group_size > 0:
        raise ValueError(
            f"q_len {q_len} should be divisible by group size {group_size}."
        )

    qkv = (
        qkv.reshape(bsz, q_len, 3, 2, self.num_heads // 2, self.head_dim)
        .permute(0, 3, 1, 2, 4, 5)
        .reshape(bsz * 2, q_len, 3, self.num_heads // 2, self.head_dim)
    )
    x = rearrange(  # pylint: disable=invalid-name
        qkv, "b s three h d -> b s (three h d)"
    )
    x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
    cu_q_len_tmp = torch.arange(
        0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype
    )
    cu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp + group_size // 2]).repeat(
        bsz, 1
    ) + cu_q_lens[:-1].unsqueeze(-1)
    cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)

    x_unpad = rearrange(
        x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2
    )
    output_unpad = flash_attn_varlen_qkvpacked_func(
        x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True
    )
    output = rearrange(
        pad_input(
            rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len
        ),
        "b s (h d) -> b s h d",
        h=nheads // 2,
    )
    output = (
        output.reshape(bsz, 2, q_len, nheads // 2, self.head_dim)
        .transpose(1, 2)
        .reshape(bsz, q_len, nheads, self.head_dim)
    )
    return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value


def flashattn_forward(
    self,
    hidden_states: torch.Tensor,
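For orientation, below is a minimal, self-contained sketch of the shifted-sparse (S2) attention idea from the LongLoRA paper (https://arxiv.org/pdf/2309.12307.pdf) that flashattn_forward_with_s2attn implements above with flash-attn varlen kernels. The function name and shapes are illustrative assumptions; it ignores padding, the KV cache, and the wrap-around group that the cu_seqlens bookkeeping above handles, and it is not the code path axolotl actually runs.

# Conceptual sketch only: group-wise causal attention with half the heads shifted
# by half a group and shifted back afterwards, as described in the LongLoRA paper.
import torch
import torch.nn.functional as F


def s2_attention_reference(q, k, v, group_size):
    # q, k, v: [bsz, n_heads, seq_len, head_dim]; seq_len must divide evenly into groups
    bsz, n_heads, seq_len, head_dim = q.shape
    assert seq_len % group_size == 0, "seq_len must be divisible by group_size"
    half = n_heads // 2

    def shift(x, offset):
        # roll the second half of the heads along the sequence dimension
        x = x.clone()
        x[:, half:] = x[:, half:].roll(offset, dims=2)
        return x

    # shift q/k/v for half the heads so neighboring groups exchange information
    q, k, v = (shift(t, -group_size // 2) for t in (q, k, v))

    # attend only within fixed-size groups by folding groups into a batch-like dim
    def group(x):
        return x.reshape(bsz, n_heads, seq_len // group_size, group_size, head_dim)

    out = F.scaled_dot_product_attention(group(q), group(k), group(v), is_causal=True)
    out = out.reshape(bsz, n_heads, seq_len, head_dim)

    # undo the shift on the outputs of the shifted heads
    return shift(out, group_size // 2)


if __name__ == "__main__":
    x = torch.randn(1, 8, 64, 16)  # group_size 16 matches GROUP_SIZE_RATIO = 1/4 of seq_len 64
    print(s2_attention_reference(x, x, x, group_size=16).shape)  # torch.Size([1, 8, 64, 16])

Roughly speaking, the patched forward above reaches the same grouping without materializing shifted copies: it splits the heads in two and doubles the batch, then describes the (half-group-offset) group boundaries to flash_attn_varlen_qkvpacked_func through cu_seqlens.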
@@ -256,31 +256,55 @@ def load_model(

         replace_stablelm_attn_with_flash_attn(cfg.base_model)

-    if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
-        if cfg.device not in ["mps", "cpu"] and not inference:
+    if cfg.sample_packing and cfg.s2_attention:
+        raise ValueError(
+            "Received `sample_packing=true` and `s2_attention=true`; however, \
+            shifted-sparse attention does not currently support sample packing."
+        )
+
+    # Modify all llama derived models in one block
+    if cfg.is_llama_derived_model:
+        if cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )

-            LOG.info("patching with flash attention for sample packing")
-            replace_llama_attn_with_flash_attn(
-                packed=cfg.sample_packing,
-                cross_entropy=cfg.flash_attn_cross_entropy,
-                rms_norm=cfg.flash_attn_rms_norm,
+            if cfg.sample_packing:
+                if cfg.device not in ["mps", "cpu"] and not inference:
+                    LOG.info("patching with flash attention for sample packing")
+                    replace_llama_attn_with_flash_attn(
+                        packed=True,
+                        cross_entropy=cfg.flash_attn_cross_entropy,
+                        rms_norm=cfg.flash_attn_rms_norm,
+                    )
+            elif cfg.s2_attention:
+                LOG.info("patching w/ flash-enabled, shifted-sparse attention")
+                replace_llama_attn_with_flash_attn(
+                    packed=False,
+                    cross_entropy=cfg.flash_attn_cross_entropy,
+                    rms_norm=cfg.flash_attn_rms_norm,
+                    use_shifted_sparse_attn=True,
+                )
+        elif cfg.xformers_attention:
+            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
+                hijack_llama_attention,
+            )
-    elif cfg.is_llama_derived_model and cfg.xformers_attention:
-        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-            hijack_llama_attention,
-        )

-        LOG.info("patching with xformers attention")
-        hijack_llama_attention()
-    elif cfg.is_llama_derived_model and cfg.sdp_attention:
-        from axolotl.monkeypatch.llama_attn_hijack_sdp import hijack_llama_sdp_attention
+            LOG.info("patching with xformers attention")
+            hijack_llama_attention()
+        elif cfg.sdp_attention:
+            from axolotl.monkeypatch.llama_attn_hijack_sdp import (
+                hijack_llama_sdp_attention,
+            )

-        LOG.info("patching with sdp attention")
-        hijack_llama_sdp_attention()
+            LOG.info("patching with sdp attention")
+            hijack_llama_sdp_attention()
+        elif cfg.s2_attention:
+            raise NotImplementedError(
+                "Shifted-sparse attention not currently implemented without flash attention."
+            )

     # Modify mistral derived models
     if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
@@ -387,9 +411,12 @@ def load_model(
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )

    # sample packing uses custom FA2 patch
    if cfg.flash_attention:
        if not cfg.sample_packing:
            if cfg.s2_attention:
                pass
            if (
                cfg.is_llama_derived_model
                or cfg.is_falcon_derived_model
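Because the hunk above interleaves old and new branching, here is a condensed, hypothetical summary of the resulting llama patch selection. The helper name and the returned strings are invented for readability; the real load_model calls the monkeypatch functions directly and additionally guards on cfg.device and the inference flag.

# Hypothetical condensation of the attention-patch dispatch in load_model above.
def select_llama_attention_patch(cfg: dict) -> str:
    if cfg.get("sample_packing") and cfg.get("s2_attention"):
        raise ValueError(
            "shifted-sparse attention does not currently support sample packing"
        )
    if cfg.get("flash_attention"):
        if cfg.get("sample_packing"):
            return "replace_llama_attn_with_flash_attn(packed=True, ...)"
        if cfg.get("s2_attention"):
            return "replace_llama_attn_with_flash_attn(packed=False, use_shifted_sparse_attn=True, ...)"
        return "no monkeypatch here; plain flash attention is configured elsewhere"
    if cfg.get("xformers_attention"):
        return "hijack_llama_attention()"
    if cfg.get("sdp_attention"):
        return "hijack_llama_sdp_attention()"
    if cfg.get("s2_attention"):
        raise NotImplementedError(
            "Shifted-sparse attention not currently implemented without flash attention."
        )
    return "no attention patch"


print(select_llama_attention_patch({"flash_attention": True, "s2_attention": True}))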
111  tests/e2e/patched/test_llama_s2_attention.py  Normal file
@@ -0,0 +1,111 @@
"""
E2E tests for llama w/ S2 attn
"""

import logging
import os
import unittest
from pathlib import Path

from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault

from ..utils import with_temp_dir

LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"


class TestLlamaShiftedSparseAttention(unittest.TestCase):
    """
    Test case for Llama models using S2 Attn
    """

    @with_temp_dir
    def test_lora_s2_attn(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "sequence_len": 16384,
                "sample_packing": False,
                "flash_attention": True,
                "s2_attention": True,
                "load_in_8bit": True,
                "adapter": "lora",
                "lora_r": 32,
                "lora_alpha": 16,
                "lora_dropout": 0.05,
                "lora_target_linear": True,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "Yukang/LongAlpaca-12k",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 10,
                "save_steps": 5,
                "eval_steps": 5,
                "bf16": "auto",
            }
        )

        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "adapter_model.bin").exists()

    @with_temp_dir
    def test_fft_s2_attn(self, temp_dir):
        # pylint: disable=duplicate-code
        cfg = DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "sequence_len": 16384,
                "sample_packing": False,
                "flash_attention": True,
                "s2_attention": True,
                "val_set_size": 0.1,
                "special_tokens": {},
                "datasets": [
                    {
                        "path": "Yukang/LongAlpaca-12k",
                        "type": "alpaca",
                    },
                ],
                "num_epochs": 2,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
                "output_dir": temp_dir,
                "learning_rate": 0.00001,
                "optimizer": "adamw_torch",
                "lr_scheduler": "cosine",
                "max_steps": 10,
                "save_steps": 5,
                "eval_steps": 5,
                "bf16": "auto",
            }
        )

        normalize_config(cfg)
        cli_args = TrainerCliArgs()
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
        assert (Path(temp_dir) / "pytorch_model.bin").exists()
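A note on running these smoke tests: they live under tests/e2e/patched and exercise flash_attn_varlen_qkvpacked_func, so a CUDA GPU with the flash-attn package installed is presumably required. A minimal invocation sketch using pytest's Python API:

# Illustrative: run the new S2-attention smoke tests directly via pytest.
import pytest

pytest.main(["-v", "tests/e2e/patched/test_llama_s2_attention.py"])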
37  tests/utils/test_models.py  Normal file
@@ -0,0 +1,37 @@
"""Module for testing models utils file."""


import unittest
from unittest.mock import patch

import pytest

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model


class ModelsUtilsTest(unittest.TestCase):
    """Testing module for models utils."""

    def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
        cfg = DictDefault(
            {
                "s2_attention": True,
                "sample_packing": True,
                "base_model": "",
                "model_type": "LlamaForCausalLM",
            }
        )

        # Mock out call to HF hub
        with patch(
            "axolotl.utils.models.load_model_config"
        ) as mocked_load_model_config:
            mocked_load_model_config.return_value = {}
            with pytest.raises(ValueError) as exc:
                # Should error before hitting tokenizer, so we pass in an empty str
                load_model(cfg, tokenizer="")
            assert (
                "shifted-sparse attention does not currently support sample packing"
                in str(exc.value)
            )