simplify haldning for newer multipack patches so they can be added in a single place (#1270)

This commit is contained in:
Wing Lian
2024-02-07 10:46:04 -05:00
committed by GitHub
parent 411293bdca
commit 5698943263
7 changed files with 46 additions and 88 deletions

View File

@@ -28,6 +28,7 @@ from transformers import (
from transformers.trainer_utils import seed_worker
from trl import DPOTrainer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
@@ -994,7 +995,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
if use_batch_sampler_collator:
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for falcon
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -2,9 +2,6 @@
Patches to support multipack for mixtral
"""
import torch
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def patch_mixtral_moe_forward_zero3() -> None:
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if for_zero3:
patch_mixtral_moe_forward_zero3()

View File

@@ -0,0 +1,30 @@
"""multipack patching for v2 of sample packing"""
import transformers
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
def patch_for_multipack(model_type):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for phi2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_phi_attn_with_multipack_flash_attn():
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -1,12 +0,0 @@
"""
Patches to support multipack for qwen2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_qwen2_attn_with_multipack_flash_attn():
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -29,6 +29,10 @@ from transformers import ( # noqa: F401
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
@@ -299,8 +303,15 @@ def load_model(
shifted-sparse attention does not currently support sample packing."
)
# Modify all llama derived models in one block
if cfg.is_llama_derived_model:
if (
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
if cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
@@ -354,43 +365,6 @@ def load_model(
LOG.info("patching mistral with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if (
cfg.model_config_type == "mixtral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mixtral import (
replace_mixtral_attn_with_multipack_flash_attn,
)
LOG.info("patching mixtral with flash attention")
mixtral_patch_kwargs = {}
if is_deepspeed_zero3_enabled():
mixtral_patch_kwargs["for_zero3"] = True
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.falcon import (
replace_falcon_attn_with_multipack_flash_attn,
)
LOG.info("patching falcon with flash attention")
replace_falcon_attn_with_multipack_flash_attn()
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
LOG.info("patching phi with flash attention")
replace_phi_attn_with_multipack_flash_attn()
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.qwen2 import (
replace_qwen2_attn_with_multipack_flash_attn,
)
LOG.info("patching qwen2 with flash attention")
replace_qwen2_attn_with_multipack_flash_attn()
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -501,7 +475,7 @@ def load_model(
"flash_attention_2"
)
else:
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"