diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 7d0df8a45..40dedb456 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -891,7 +891,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             if "max_length" in kwargs:
                 kwargs.pop("max_length")
         elif use_batch_sampler_collator:
-            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES or (
+            if self.cfg.flex_attention:
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            elif self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
+                collator = V2BatchSamplerDataCollatorForSeq2Seq
+            elif (
                 self.cfg.model_config_type in ["llama"]
                 and self.cfg.flash_attention is not True
             ):
diff --git a/src/axolotl/monkeypatch/attention/flex_attn.py b/src/axolotl/monkeypatch/attention/flex_attn.py
new file mode 100644
index 000000000..8b69c2c49
--- /dev/null
+++ b/src/axolotl/monkeypatch/attention/flex_attn.py
@@ -0,0 +1,48 @@
+"""Flex attention monkey patch"""
+
+import torch
+import transformers
+
+
+def patch_flex():
+    is_torch_2_6 = torch.__version__.startswith("2.6")
+    is_transformers_below_4_51 = transformers.__version__ < "4.51.0"
+
+    if is_torch_2_6 and is_transformers_below_4_51:
+        from torch.nn.attention.flex_attention import flex_attention
+
+        class WrappedFlexAttention:
+            """
+            We are doing a singleton class so that flex attention is compiled once when it's first called.
+            """
+
+            _instance = None
+            _is_flex_compiled = False
+            _compiled_flex_attention = None
+
+            def __new__(cls, *args, **kwargs):
+                if cls._instance is None:
+                    # Create a new instance if one doesn't already exist
+                    cls._instance = super().__new__(cls)
+                return cls._instance
+
+            @torch.compiler.disable(recursive=False)
+            def __init__(self):
+                """
+                Initialize or update the singleton instance.
+                """
+                if not self._is_flex_compiled:
+                    self._compiled_flex_attention = torch.compile(
+                        flex_attention,
+                        dynamic=False,
+                        mode="max-autotune-no-cudagraphs",
+                        fullgraph=True,
+                    )
+                    self._is_flex_compiled = True
+
+            def __call__(self):
+                return self._compiled_flex_attention
+
+        transformers.integrations.flex_attention.WrappedFlexAttention = (
+            WrappedFlexAttention
+        )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 301607865..663aa1740 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -578,7 +578,7 @@ class ModelLoader:

         if (
             self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
-            and self.cfg.flash_attention
+            and (self.cfg.flash_attention or self.cfg.flex_attention)
             and self.cfg.sample_packing
         ):
             if "auto_map" in self.model_config:
@@ -884,7 +884,16 @@ class ModelLoader:
         """
         sample packing uses custom FA2 patch
         """
-        if self.cfg.flash_attention:
+        if self.cfg.flex_attention:
+            self.model_kwargs["attn_implementation"] = "flex_attention"
+            self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                "flex_attention"
+            )
+            from axolotl.monkeypatch.attention.flex_attn import patch_flex
+
+            patch_flex()
+
+        elif self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
             self.model_kwargs["attn_implementation"] = "flash_attention_2"
@@ -1281,7 +1290,10 @@ class ModelLoader:

         should_convert = (
             # LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
             # convert them back to fp16/bf16 for flash-attn compatibility.
-            ((needs_fa2_dtype or self.cfg.flash_attention) and not qlora_fsdp)
+            (
+                (needs_fa2_dtype or self.cfg.flash_attention or self.cfg.flex_attention)
+                and not qlora_fsdp
+            )
             or self.cfg.cut_cross_entropy  # Cut cross entropy requires embedding layers to be in fp16/bf16 for backward pass
         )
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index c9a208ba2..cf98f7f02 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -223,6 +223,7 @@ class AxolotlInputConfig(
     xformers_attention: bool | None = None
     sdp_attention: bool | None = None
     s2_attention: bool | None = None
+    flex_attention: bool | None = None
     flash_attention: bool | None = None
     flash_attn_cross_entropy: bool | None = None
     flash_attn_rms_norm: bool | None = None
@@ -355,6 +356,22 @@ class AxolotlInputConfig(
             return [ds_config.model_dump(exclude_none=True) for ds_config in ds_configs]
         return None

+    @model_validator(mode="before")
+    @classmethod
+    def check_attention_fields(cls, data):
+        fields = (
+            "xformers_attention",
+            "sdp_attention",
+            "s2_attention",
+            "flash_attention",
+            "flex_attention",
+        )
+        non_empty_count = sum(1 for field in fields if data.get(field))
+
+        if non_empty_count > 1:
+            raise ValueError(f"Only one of {', '.join(fields)} may be set")
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_batch_size_fields(cls, data):
@@ -1250,6 +1267,24 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
             )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_flex_torch_version(cls, data):
+        if (data.get("flex_attention") is not None) and (data.get("flex_attention")):
+            env_capabilities = data.get("env_capabilities", {})
+            torch_version = env_capabilities.get("torch_version")
+
+            if torch_version is None:
+                import torch
+
+                torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
+
+            if version.parse(torch_version) < version.parse("2.6.0"):
+                raise ValueError(
+                    "Flex attention is not supported on torch version < 2.6.0"
+                )
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_torch_compile_auto(cls, data):
diff --git a/tests/e2e/multigpu/solo/test_flex.py b/tests/e2e/multigpu/solo/test_flex.py
new file mode 100644
index 000000000..5e2c9e7cc
--- /dev/null
+++ b/tests/e2e/multigpu/solo/test_flex.py
@@ -0,0 +1,92 @@
+"""
+E2E tests for multi-GPU packed training w/ flex attention
+"""
+
+import logging
+import os
+from pathlib import Path
+
+import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async
+from huggingface_hub import snapshot_download
+from transformers.testing_utils import get_torch_dist_unique_port
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.utils.dict import DictDefault
+
+from tests.e2e.utils import check_tensorboard, require_torch_2_6_0
+
+LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
+os.environ["WANDB_DISABLED"] = "true"
+
+AXOLOTL_ROOT = Path(__file__).parent.parent.parent.parent
+
+
+@pytest.fixture(scope="session", autouse=True)
+def download_model():
+    # download the model
+    snapshot_download("HuggingFaceTB/SmolLM2-135M")
+
+
+class TestPackedFlex:
+    """
+    Test case for Packed training of llama models
+    """
+
+    @require_torch_2_6_0
+    def test_loss_llama(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flex_attention": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "use_tensorboard": True,
+                "save_strategy": "no",
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        # write cfg to yaml file
+        Path(temp_dir).mkdir(parents=True, exist_ok=True)
+        with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+            fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+        execute_subprocess_async(
+            [
+                "axolotl",
+                "train",
+                str(Path(temp_dir) / "config.yaml"),
+                "--num-processes",
+                "2",
+                "--main-process-port",
+                f"{get_torch_dist_unique_port()}",
+            ]
+        )
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )
diff --git a/tests/e2e/solo/test_flex.py b/tests/e2e/solo/test_flex.py
new file mode 100644
index 000000000..5e52204a6
--- /dev/null
+++ b/tests/e2e/solo/test_flex.py
@@ -0,0 +1,73 @@
+"""
+E2E tests for packed training w/ flex attention
+"""
+
+import logging
+import os
+import unittest
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli.args import TrainerCliArgs
+from axolotl.common.datasets import load_datasets
+from axolotl.train import train
+from axolotl.utils.config import normalize_config, validate_config
+from axolotl.utils.dict import DictDefault
+
+from ..utils import check_tensorboard, require_torch_2_6_0, with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestPackedFlex(unittest.TestCase):
+    """
+    Test case for Packed training of llama models
+    """
+
+    @require_torch_2_6_0
+    @with_temp_dir
+    def test_loss_llama(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "sequence_len": 1024,
+                "sample_packing": True,
+                "flex_attention": True,
+                "val_set_size": 0.0,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 4,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch_fused",
+                "lr_scheduler": "cosine",
+                "max_steps": 5,
+                "use_tensorboard": True,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+
+        check_tensorboard(
+            temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
+        )
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 2b218fbf5..2fbf333c4 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -67,9 +67,21 @@ def require_torch_2_5_1(test_case):
     return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)


+def require_torch_2_6_0(test_case):
+    """
+    Decorator marking a test that requires torch >= 2.6.0
+    """
+
+    def is_min_2_6_0():
+        torch_version = version.parse(torch.__version__)
+        return torch_version >= version.parse("2.6.0")
+
+    return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
+
+
 def require_torch_lt_2_6_0(test_case):
     """
-    Decorator marking a test that requires torch >= 2.5.1
+    Decorator marking a test that requires torch < 2.6.0
     """

     def is_max_2_6_0():