From 5f58555bd0dbf15cae25fc021eb00421e53e47b2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 16 Jul 2024 17:36:29 -0400 Subject: [PATCH] support for llama multipack using updated code/patches (#1754) * support for llama multipack using updated code/patches * also support unsloth patches * incorrect arg * add config validation for unsloth * add missing return to validation * add another missing return to validation --- .../monkeypatch/llama_attn_hijack_flash.py | 50 +++++++++++-------- src/axolotl/monkeypatch/multipack.py | 5 ++ .../config/models/input/v0_4_1/__init__.py | 40 +++++++++++++++ src/axolotl/utils/models.py | 21 ++++++++ 4 files changed, 95 insertions(+), 21 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 9377cb03f..4c3571ea4 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -78,6 +78,33 @@ def replace_llama_qkv_with_fused(model): set_module_name(model, name, qkv) +def patch_llama_cross_entropy(): + from flash_attn.losses.cross_entropy import CrossEntropyLoss + + LOG.info("patching with flash_attn.losses.cross_entropy") + transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( + CrossEntropyLoss, inplace_backward=True + ) + + +def patch_llama_rms_norm(): + try: + from flash_attn.ops.rms_norm import RMSNorm + + class LlamaRMSNorm(RMSNorm): + """Patched LLamaRMSNorm""" + + def __init__(self, hidden_size, eps=1e-6): + super().__init__(hidden_size, eps=eps) + + LOG.info("patching with flash_attn.ops.rms_norm") + transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm + except ImportError: + LOG.warning( + "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" + ) + + def replace_llama_attn_with_flash_attn( packed: Optional[bool] = False, cross_entropy: Optional[bool] = False, @@ -104,30 +131,11 @@ def replace_llama_attn_with_flash_attn( # skip only if explicitly disabled if cross_entropy: - from flash_attn.losses.cross_entropy import CrossEntropyLoss - - LOG.info("patching with flash_attn.losses.cross_entropy") - transformers.models.llama.modeling_llama.CrossEntropyLoss = partial( - CrossEntropyLoss, inplace_backward=True - ) + patch_llama_cross_entropy() # skip only if explicitly disabled if rms_norm: - try: - from flash_attn.ops.rms_norm import RMSNorm - - class LlamaRMSNorm(RMSNorm): - """Patched LLamaRMSNorm""" - - def __init__(self, hidden_size, eps=1e-6): - super().__init__(hidden_size, eps=eps) - - LOG.info("patching with flash_attn.ops.rms_norm") - transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm - except ImportError: - LOG.warning( - "optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)" - ) + patch_llama_rms_norm() class FusedAttention(LlamaAttention): diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index e319596d0..017adb2bf 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "llama", "mixtral", "qwen2", "qwen2_moe", @@ -30,6 +31,10 @@ def patch_for_multipack(model_type, model_name=None): ) if is_deepspeed_zero3_enabled(): patch_mixtral_moe_forward_zero3() + elif model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) elif model_type == "qwen2": transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 3d0b02752..6cd98af11 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1112,6 +1112,31 @@ class AxolotlInputConfig( raise ValueError("either datasets or pretraining_dataset is required") return data + @model_validator(mode="before") + @classmethod + def check_xentropy_patch_conflicts(cls, data): + if data.get("flash_attn_cross_entropy") and data.get( + "unsloth_cross_entropy_loss" + ): + raise ValueError( + "flash_attn_cross_entropy and unsloth_cross_entropy_loss cannot be both enabled" + ) + return data + + @model_validator(mode="before") + @classmethod + def check_qlora_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + if data.get("adapter") == "lora" or data.get("load_in_8bit"): + raise ValueError( + "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with 8-bit LoRA" + ) + return data + class AxolotlConfigWCapabilities(AxolotlInputConfig): """wrapper to valdiate gpu capabilities with the configured options""" @@ -1163,3 +1188,18 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig): if data.get("deepspeed") and data.get("fsdp"): raise ValueError("deepspeed and fsdp cannot be used together.") return data + + @model_validator(mode="before") + @classmethod + def check_multigpu_unsloth(cls, data): + if ( + data.get("unsloth_lora_mlp") + or data.get("unsloth_lora_qkv") + or data.get("unsloth_lora_o") + ): + capabilities = data.get("capabilities") + if capabilities and capabilities.get("num_gpus") > 1: + raise ValueError( + "unsloth_lora_mlp, unsloth_lora_qkv, and unsloth_lora_o are not compatible with multi-GPU training." + ) + return data diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 19745ef8b..51ce5a29b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -347,6 +347,27 @@ def load_model( and cfg.sample_packing ): patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model) + + if cfg.is_llama_derived_model: + from axolotl.monkeypatch.llama_attn_hijack_flash import ( + patch_llama_cross_entropy, + patch_llama_rms_norm, + ) + + if cfg.flash_attn_cross_entropy: + patch_llama_cross_entropy() + if cfg.flash_attn_rms_norm: + patch_llama_rms_norm() + if cfg.unsloth_cross_entropy_loss: + from axolotl.monkeypatch.unsloth_ import ( + integrate_cross_entropy_loss_patch, + ) + + integrate_cross_entropy_loss_patch() + if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: + from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora + + patch_self_attn_lora() elif cfg.is_llama_derived_model: # Modify all llama derived models in one block