diff --git a/docs/config.qmd b/docs/config.qmd
index 18fc9dcf8..fae64501a 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -73,11 +73,12 @@ load_in_8bit: true
 load_in_4bit:
 
 # Use CUDA bf16
-bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
+bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
 # Use CUDA fp16
 fp16: true
 # Use CUDA tf32
 tf32: true # require >=ampere
+# Note: if bf16 is set to 'auto' and fp16 is set to true, we will prefer the explicit fp16 setting
 
 # No AMP (automatic mixed precision)
 bfloat16: true # require >=ampere
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 99d97800c..af9a43db3 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -21,6 +21,7 @@ import importlib.util
 import inspect
 import logging
 import math
+import os
 import sys
 from abc import abstractmethod
 from pathlib import Path
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
     causal_lm_bench_eval_callback_factory,
+    colab_inference_post_train_callback,
     log_prediction_callback_factory,
 )
 from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
             callbacks.append(lisa_callback_factory(trainer))
 
+        if any("COLAB_" in key for key in os.environ):
+            ColabCallback = colab_inference_post_train_callback(trainer)
+            callbacks.append(ColabCallback(self.cfg))
+
         callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
         return callbacks
 
diff --git a/src/axolotl/monkeypatch/attention/__init__.py b/src/axolotl/monkeypatch/attention/__init__.py
index e69de29bb..15ed764f4 100644
--- a/src/axolotl/monkeypatch/attention/__init__.py
+++ b/src/axolotl/monkeypatch/attention/__init__.py
@@ -0,0 +1,19 @@
+"""
+attention module for attention monkeypatches
+"""
+
+from transformers.integrations.flash_attention import flash_attention_forward
+
+
+def patch_xformers_attn_over_fa2():
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    from .xformers import xformers_attention_forward
+
+    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
+
+
+def unpatch_xformers_attn_over_fa2():
+    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+    ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward
diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py
new file mode 100644
index 000000000..5901963f0
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/xformers.py
@@ -0,0 +1,160 @@
+"""
+xformers attention implementation for packing
+"""
+
+from typing import Optional
+
+import torch
+import xformers
+import xformers.ops.fmha
+from transformers.modeling_flash_attention_utils import (
+    _upad_input,
+)
+
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+
+xformers_attention = xformers.ops.fmha.memory_efficient_attention
+
+
+def xformers_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    dropout: float = 0.0,  # pylint: disable=unused-argument
+    scaling: Optional[float] = None,  # pylint: disable=unused-argument
+    sliding_window: Optional[int] = None,  # pylint: disable=unused-argument
+    softcap: Optional[float] = None,  # pylint: disable=unused-argument
+    cu_seq_lens_q: Optional[torch.LongTensor] = None,
+    cu_seq_lens_k: Optional[torch.LongTensor] = None,
+    max_length_q: Optional[int] = None,
+    max_length_k: Optional[int] = None,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
+):
+    # Get dimensions
+    # query: [batch, heads, seq_len, hidden_dim]
+    batch_size = query.size(0)
+    query_length = query.shape[2]
+    key_length = key.shape[2]
+
+    # Default causal mask
+    attn_bias = xformers.ops.LowerTriangularMask()
+
+    # Check if we have sliding window attention
+    has_sliding_window = sliding_window is not None and sliding_window < query_length
+
+    # Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
+    query = query.transpose(1, 2)
+    key = key.transpose(1, 2)
+    value = value.transpose(1, 2)
+
+    # Get GQA parameters
+    num_attention_heads = module.config.num_attention_heads
+    num_key_value_heads = module.config.num_key_value_heads
+    head_dim = query.size(-1)
+    is_gqa = num_attention_heads != num_key_value_heads
+    n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
+
+    # If position_ids is provided, detect packing: monotonically increasing position_ids imply a
+    # single sequence, otherwise the batch is packed (this also covers the pre-fill/training case
+    # where varlen metadata is passed in). A block-diagonal causal bias prevents cross-example attention.
+    if position_ids is not None and (
+        max_length_q is not None
+        or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
+    ):
+        if cu_seq_lens_q is None or cu_seq_lens_k is None:
+            cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
+            cu_seq_lens_q = cu_seq_lens_q.squeeze()
+            seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+            attn_bias = (
+                xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+                    q_seqlen=seq_lengths.tolist(),
+                )
+            )
+        else:
+            query = query.reshape(-1, query.size(-2), query.size(-1))
+            key = key.reshape(-1, key.size(-2), key.size(-1))
+            value = value.reshape(-1, value.size(-2), value.size(-1))
+
+            # Handle GQA
+            if is_gqa:
+                key = key.repeat_interleave(n_groups, dim=2)
+                value = value.repeat_interleave(n_groups, dim=2)
+
+    elif attention_mask is not None:
+        query, key, value, _, cu_seq_lens, _ = _upad_input(
+            query, key, value, attention_mask, query_length
+        )
+        cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
+        seq_lengths = []
+        for i in range(len(cu_seq_lens_q) - 1):
+            seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
+        attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
+            q_seqlen=seq_lengths,
+            kv_seqlen=seq_lengths,
+        )
+
+        # Handle GQA
+        if is_gqa:
+            key = key.repeat_interleave(n_groups, dim=2)
+            value = value.repeat_interleave(n_groups, dim=2)
+    else:
+        # Handle Group Query Attention (GQA) using view/expand approach from reference
+        key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
+        value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
+        key = key.expand(
+            batch_size, key_length, num_key_value_heads, n_groups, head_dim
+        )
+        value = value.expand(
+            batch_size, key_length, num_key_value_heads, n_groups, head_dim
+        )
+
+        if module.training:
+            key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
+            value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
+
+            if has_sliding_window:
+                query = query.view(
+                    1, batch_size * query_length, num_attention_heads, head_dim
+                )
+                key = key.view(
+                    1, batch_size * key_length, num_attention_heads, head_dim
+                )
+                value = value.view(
+                    1, batch_size * key_length, num_attention_heads, head_dim
+                )
+        else:
+            query = query.view(
+                batch_size, query_length, num_key_value_heads, n_groups, head_dim
+            )
+
+            # If we need a sliding window attention
+            if has_sliding_window:
+                query = query.view(
+                    1,
+                    batch_size * query_length,
+                    num_key_value_heads,
+                    n_groups,
+                    head_dim,
+                )
+                key = key.view(
+                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
+                )
+                value = value.view(
+                    1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
+                )
+
+    # Run the xformers attention
+    attn_output = xformers_attention(
+        query,
+        key,
+        value,
+        attn_bias=attn_bias,
+    )
+
+    attn_output = attn_output.view(
+        batch_size, -1, attn_output.size(-2), attn_output.size(-1)
+    )
+    return attn_output, None
diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py
index 21b14d986..0e7b06093 100644
--- a/src/axolotl/utils/callbacks/__init__.py
+++ b/src/axolotl/utils/callbacks/__init__.py
@@ -868,3 +868,28 @@ class GCCallback(TrainerCallback):
     ):
         torch.cuda.empty_cache()
         gc.collect()
+
+
+def colab_inference_post_train_callback(trainer: Trainer):
+    class ColabCallback(TrainerCallback):
+        """Callback to prep model for inference on Google Colab"""
+
+        def __init__(self, cfg):
+            self.gpu_name = torch.cuda.get_device_name(0)
+            self.cfg = cfg
+
+        def on_train_end(
+            self, args, state, control, **kwargs
+        ):  # pylint: disable=unused-argument
+            """
+            Handle T4 GPUs: convert attention to eager for inference
+            """
+            if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
+                trainer.model.config._attn_implementation = (  # pylint: disable=protected-access
+                    "eager"
+                )
+                trainer.model.gradient_checkpointing_disable()
+                trainer.model.config.use_cache = True
+                trainer.model.eval()
+
+    return ColabCallback
diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py
index 0de87fa5c..a96cc1286 100644
--- a/src/axolotl/utils/config/__init__.py
+++ b/src/axolotl/utils/config/__init__.py
@@ -70,6 +70,9 @@ def resolve_dtype(cfg):
     if cfg.fp16 is None and not cfg.float16:
         cfg.fp16 = True
 
+    if cfg.fp16 and cfg.bf16 == "auto":
+        cfg.bf16 = False
+
     if cfg.device == "mps":
         cfg.load_in_8bit = False
         cfg.tf32 = False
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e88de1bad..d30ec0b56 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -540,6 +540,11 @@ class ModelLoader:
         self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name
 
     def apply_patches(self) -> None:
+        if self.cfg.xformers_attention and self.cfg.sample_packing:
+            from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
+
+            patch_xformers_attn_over_fa2()
+            self.cfg.flash_attention = True
         if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
             from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
 
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 02308695c..3527ec56e 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -435,16 +435,6 @@ class AxolotlInputConfig(
         )
         return data
 
-    @model_validator(mode="before")
-    @classmethod
-    def check_sample_packing_w_xformers(cls, data):
-        if data.get("sample_packing") and data.get("xformers_attention"):
-            raise ValueError(
-                "sample_packing not compatible with xformers_attention. Use flash_attention"
-            )
-
-        return data
-
     @model_validator(mode="before")
     @classmethod
     # pylint: disable=duplicate-code
@@ -471,9 +461,10 @@
             and not data.get("flash_attention")
             and not data.get("sdp_attention")
             and not data.get("flex_attention")
+            and not data.get("xformers_attention")
         ):
             LOG.warning(
-                "sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
+                "sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
             )
 
         return data
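For readers unfamiliar with the xformers block-diagonal bias that the new `xformers_attention_forward` relies on, here is a minimal, self-contained sketch (not part of the patch). It assumes xformers is installed and a CUDA device is available; the position-id length recovery below is a simplified stand-in for axolotl's `get_cu_seqlens_from_pos_ids`, and all tensor sizes are arbitrary.

```python
import torch
import xformers.ops.fmha as fmha

# One packed "batch" of 10 tokens holding three samples of lengths 4, 3 and 3;
# with sample packing, position_ids restart at 0 at each sample boundary.
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2]])

# Recover per-sample lengths from the reset points
# (simplified stand-in for get_cu_seqlens_from_pos_ids).
starts = torch.nonzero(position_ids[0] == 0).flatten()
bounds = torch.cat([starts, torch.tensor([position_ids.size(1)])])
seq_lengths = (bounds[1:] - bounds[:-1]).tolist()  # -> [4, 3, 3]

# Block-diagonal causal bias: each token attends only within its own sample, causally.
attn_bias = fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(q_seqlen=seq_lengths)

# xformers expects [batch, tokens, heads, head_dim]; block-diagonal biases require batch == 1.
num_heads, head_dim = 8, 64
q = torch.randn(1, 10, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
print(out.shape)  # torch.Size([1, 10, 8, 64])
```

Swapping the default `LowerTriangularMask` for this `BlockDiagonalCausalMask` is what keeps packed samples from attending to one another while still using a single fused attention call, which is why the patch can register the xformers forward in place of flash attention 2 when `sample_packing` is enabled.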