diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 73eddd426..b1c37d8ba 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -28,6 +28,7 @@ from transformers import (
 from transformers.trainer_utils import seed_worker
 from trl import DPOTrainer
 
+from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
@@ -994,7 +995,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ]
         ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
+            if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             elif (
                 self.cfg.model_config_type in ["llama"]
diff --git a/src/axolotl/monkeypatch/falcon/__init__.py b/src/axolotl/monkeypatch/falcon/__init__.py
deleted file mode 100644
index dc6e526f6..000000000
--- a/src/axolotl/monkeypatch/falcon/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""
-Patches to support multipack for falcon
-"""
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
-
-
-def replace_falcon_attn_with_multipack_flash_attn():
-    transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py
index 112b0dee9..d6ee0ce16 100644
--- a/src/axolotl/monkeypatch/mixtral/__init__.py
+++ b/src/axolotl/monkeypatch/mixtral/__init__.py
@@ -2,9 +2,6 @@
 Patches to support multipack for mixtral
 """
 import torch
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
 
 
 def patch_mixtral_moe_forward_zero3() -> None:
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
 
     MixtralBLockSparseTop2MLP.forward = mlp_forward
     MixtralSparseMoeBlock.forward = moe_forward
-
-
-def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
-    transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
-    if for_zero3:
-        patch_mixtral_moe_forward_zero3()
diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py
new file mode 100644
index 000000000..640b3b0c3
--- /dev/null
+++ b/src/axolotl/monkeypatch/multipack.py
@@ -0,0 +1,30 @@
+"""multipack patching for v2 of sample packing"""
+
+import transformers
+from transformers.integrations import is_deepspeed_zero3_enabled
+
+from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
+from axolotl.monkeypatch.utils import get_unpad_data
+
+SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
+
+
+def patch_for_multipack(model_type):
+    if model_type == "mixtral":
+        transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        if is_deepspeed_zero3_enabled():
+            patch_mixtral_moe_forward_zero3()
+    elif model_type == "qwen2":
+        transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "falcon":
+        transformers.models.falcon.modeling_falcon._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+    elif model_type == "phi":
+        transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
diff --git a/src/axolotl/monkeypatch/phi/__init__.py b/src/axolotl/monkeypatch/phi/__init__.py
deleted file mode 100644
index 1076708a0..000000000
--- a/src/axolotl/monkeypatch/phi/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""
-Patches to support multipack for phi2
-"""
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
-
-
-def replace_phi_attn_with_multipack_flash_attn():
-    transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
diff --git a/src/axolotl/monkeypatch/qwen2/__init__.py b/src/axolotl/monkeypatch/qwen2/__init__.py
deleted file mode 100644
index 40c54d21e..000000000
--- a/src/axolotl/monkeypatch/qwen2/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""
-Patches to support multipack for qwen2
-"""
-import transformers
-
-from axolotl.monkeypatch.utils import get_unpad_data
-
-
-def replace_qwen2_attn_with_multipack_flash_attn():
-    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
-        get_unpad_data
-    )
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index e1748fd7b..2a1507c83 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -29,6 +29,10 @@ from transformers import (  # noqa: F401
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
+from axolotl.monkeypatch.multipack import (
+    SUPPORTED_MULTIPACK_MODEL_TYPES,
+    patch_for_multipack,
+)
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
@@ -299,8 +303,15 @@ def load_model(
 shifted-sparse attention does not currently support sample packing."
         )
 
-    # Modify all llama derived models in one block
-    if cfg.is_llama_derived_model:
+    if (
+        cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
+        and cfg.flash_attention
+        and cfg.sample_packing
+    ):
+        patch_for_multipack(cfg.model_config_type)
+    elif cfg.is_llama_derived_model:
+        # Modify all llama derived models in one block
+        if cfg.flash_attention:
             from axolotl.monkeypatch.llama_attn_hijack_flash import (
                 replace_llama_attn_with_flash_attn,
             )
@@ -354,43 +365,6 @@
         LOG.info("patching mistral with flash attention")
         replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
 
-    if (
-        cfg.model_config_type == "mixtral"
-        and cfg.flash_attention
-        and cfg.sample_packing
-    ):
-        from axolotl.monkeypatch.mixtral import (
-            replace_mixtral_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching mixtral with flash attention")
-        mixtral_patch_kwargs = {}
-        if is_deepspeed_zero3_enabled():
-            mixtral_patch_kwargs["for_zero3"] = True
-        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
-
-    if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.falcon import (
-            replace_falcon_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching falcon with flash attention")
-        replace_falcon_attn_with_multipack_flash_attn()
-
-    if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
-
-        LOG.info("patching phi with flash attention")
-        replace_phi_attn_with_multipack_flash_attn()
-
-    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
-        from axolotl.monkeypatch.qwen2 import (
-            replace_qwen2_attn_with_multipack_flash_attn,
-        )
-
-        LOG.info("patching qwen2 with flash attention")
-        replace_qwen2_attn_with_multipack_flash_attn()
-
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
 
@@ -501,7 +475,7 @@
                     "flash_attention_2"
                 )
         else:
-            if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
+            if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
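Reviewer note: the net effect of this diff is that the four copy-pasted per-model `_get_unpad_data` patches collapse into one guarded entry point, so supporting a new model type means touching only `SUPPORTED_MULTIPACK_MODEL_TYPES` and `patch_for_multipack`. Below is a minimal sketch of how the new API composes; the standalone driver is illustrative only (inside axolotl the call happens in `load_model` when `flash_attention` and `sample_packing` are both set), and the hard-coded `model_type` stands in for `cfg.model_config_type`:

```python
import transformers

from axolotl.monkeypatch.multipack import (
    SUPPORTED_MULTIPACK_MODEL_TYPES,
    patch_for_multipack,
)
from axolotl.monkeypatch.utils import get_unpad_data

# Stand-in for cfg.model_config_type; any of "mixtral", "qwen2",
# "falcon", "phi" dispatches the same way.
model_type = "qwen2"

if model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
    # Swaps the module-level _get_unpad_data helper that the HF
    # flash-attention path calls, enabling multipack sample packing.
    patch_for_multipack(model_type)

# The patch is a plain attribute swap, so it can be sanity-checked:
assert (
    transformers.models.qwen2.modeling_qwen2._get_unpad_data  # pylint: disable=protected-access
    is get_unpad_data
)
```

One behavioral nuance worth flagging in review: the mixtral branch now decides the zero3 MoE patch internally via `is_deepspeed_zero3_enabled()`, whereas the deleted `replace_mixtral_attn_with_multipack_flash_attn(for_zero3=...)` took that as a caller-supplied kwarg, so callers no longer need to know about DeepSpeed state.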