diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py
index 394169eef..5901963f0 100644
--- a/src/axolotl/monkeypatch/attention/xformers.py
+++ b/src/axolotl/monkeypatch/attention/xformers.py
@@ -7,7 +7,6 @@ from typing import Optional
 import torch
 import xformers
 import xformers.ops.fmha
-from flash_attn.bert_padding import pad_input
 from transformers.modeling_flash_attention_utils import (
     _upad_input,
 )
@@ -34,14 +33,29 @@ def xformers_attention_forward(
     max_length_k: Optional[int] = None,  # pylint: disable=unused-argument
     **kwargs,  # pylint: disable=unused-argument
 ):
+    # Get dimensions
+    # query: [batch, heads, seq_len, hidden_dim]
+    batch_size = query.size(0)
+    query_length = query.shape[2]
+    key_length = key.shape[2]
+    # Default causal mask
+    attn_bias = xformers.ops.LowerTriangularMask()
+
+    # Check if we have sliding window attention
+    has_sliding_window = sliding_window is not None and sliding_window < query_length
+
+    # Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
-    query_length = query.shape[2]
-    batch_size = query.size(0)
-    attn_bias = xformers.ops.LowerTriangularMask()
+    # Get GQA parameters
+    num_attention_heads = module.config.num_attention_heads
+    num_key_value_heads = module.config.num_key_value_heads
+    head_dim = query.size(-1)
+    is_gqa = num_attention_heads != num_key_value_heads
+    n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
 
     # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
     # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
@@ -64,25 +78,13 @@ def xformers_attention_forward(
         key = key.reshape(-1, key.size(-2), key.size(-1))
         value = value.reshape(-1, value.size(-2), value.size(-1))
 
-        # pylint: disable=duplicate-code
-        if module.config.num_attention_heads != module.config.num_key_value_heads:
-            key = key.repeat_interleave(
-                module.config.num_attention_heads // module.config.num_key_value_heads,
-                dim=2,
-            )
-            value = value.repeat_interleave(
-                module.config.num_attention_heads // module.config.num_key_value_heads,
-                dim=2,
-            )
+        # Handle GQA
+        if is_gqa:
+            key = key.repeat_interleave(n_groups, dim=2)
+            value = value.repeat_interleave(n_groups, dim=2)
 
-        attn_output = xformers_attention(
-            query,
-            key,
-            value,
-            attn_bias=attn_bias,
-        )
     elif attention_mask is not None:
-        query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
+        query, key, value, _, cu_seq_lens, _ = _upad_input(
             query, key, value, attention_mask, query_length
         )
         cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
@@ -94,31 +96,63 @@ def xformers_attention_forward(
             kv_seqlen=seq_lengths,
        )
 
-        attn_output_unpad = xformers_attention(
-            query,
-            key,
-            value,
-            attn_bias=attn_bias,
-        )
-        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        # Handle GQA
+        if is_gqa:
+            key = key.repeat_interleave(n_groups, dim=2)
+            value = value.repeat_interleave(n_groups, dim=2)
     else:
-        # pylint: disable=duplicate-code
-        if module.config.num_attention_heads != module.config.num_key_value_heads:
-            key = key.repeat_interleave(
-                module.config.num_attention_heads // module.config.num_key_value_heads,
-                dim=2,
-            )
-            value = value.repeat_interleave(
-                module.config.num_attention_heads // module.config.num_key_value_heads,
-                dim=2,
+        # Handle Group Query Attention (GQA) using view/expand approach from reference
+        key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
+        value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
+        key = key.expand(
+            batch_size, key_length, num_key_value_heads, n_groups, head_dim
+        )
+        value = value.expand(
+            batch_size, key_length, num_key_value_heads, n_groups, head_dim
+        )
+
+        if module.training:
+            key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
+            value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
+
+            if has_sliding_window:
+                query = query.view(
+                    1, batch_size * query_length, num_attention_heads, head_dim
+                )
+                key = key.view(
+                    1, batch_size * key_length, num_attention_heads, head_dim
+                )
+                value = value.view(
+                    1, batch_size * key_length, num_attention_heads, head_dim
+                )
+        else:
+            query = query.view(
+                batch_size, query_length, num_key_value_heads, n_groups, head_dim
            )
 
-        attn_output = xformers_attention(
-            query,
-            key,
-            value,
-            attn_bias=attn_bias,
-        )
+            # If we need a sliding window attention
+            if has_sliding_window:
+                query = query.view(
+                    1,
+                    batch_size * query_length,
+                    num_key_value_heads,
+                    n_groups,
+                    head_dim,
+                )
+                key = key.view(
+                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
+                )
+                value = value.view(
+                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
+                )
+
+    # Run the xformers attention
+    attn_output = xformers_attention(
+        query,
+        key,
+        value,
+        attn_bias=attn_bias,
+    )
 
     attn_output = attn_output.view(
         batch_size, -1, attn_output.size(-2), attn_output.size(-1)