simplify haldning for newer multipack patches so they can be added in a single place (#1270)
This commit is contained in:
@@ -28,6 +28,7 @@ from transformers import (
|
|||||||
from transformers.trainer_utils import seed_worker
|
from transformers.trainer_utils import seed_worker
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
@@ -994,7 +995,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
]
|
]
|
||||||
]
|
]
|
||||||
if use_batch_sampler_collator:
|
if use_batch_sampler_collator:
|
||||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||||
elif (
|
elif (
|
||||||
self.cfg.model_config_type in ["llama"]
|
self.cfg.model_config_type in ["llama"]
|
||||||
|
|||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
Patches to support multipack for falcon
|
|
||||||
"""
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def replace_falcon_attn_with_multipack_flash_attn():
|
|
||||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
@@ -2,9 +2,6 @@
|
|||||||
Patches to support multipack for mixtral
|
Patches to support multipack for mixtral
|
||||||
"""
|
"""
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def patch_mixtral_moe_forward_zero3() -> None:
|
def patch_mixtral_moe_forward_zero3() -> None:
|
||||||
@@ -51,11 +48,3 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
|||||||
|
|
||||||
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
||||||
MixtralSparseMoeBlock.forward = moe_forward
|
MixtralSparseMoeBlock.forward = moe_forward
|
||||||
|
|
||||||
|
|
||||||
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
|
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
if for_zero3:
|
|
||||||
patch_mixtral_moe_forward_zero3()
|
|
||||||
|
|||||||
30
src/axolotl/monkeypatch/multipack.py
Normal file
30
src/axolotl/monkeypatch/multipack.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""multipack patching for v2 of sample packing"""
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||||
|
from axolotl.monkeypatch.utils import get_unpad_data
|
||||||
|
|
||||||
|
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
||||||
|
|
||||||
|
|
||||||
|
def patch_for_multipack(model_type):
|
||||||
|
if model_type == "mixtral":
|
||||||
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
elif model_type == "qwen2":
|
||||||
|
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "falcon":
|
||||||
|
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "phi":
|
||||||
|
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
Patches to support multipack for phi2
|
|
||||||
"""
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def replace_phi_attn_with_multipack_flash_attn():
|
|
||||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
"""
|
|
||||||
Patches to support multipack for qwen2
|
|
||||||
"""
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_unpad_data
|
|
||||||
|
|
||||||
|
|
||||||
def replace_qwen2_attn_with_multipack_flash_attn():
|
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
|
||||||
get_unpad_data
|
|
||||||
)
|
|
||||||
@@ -29,6 +29,10 @@ from transformers import ( # noqa: F401
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
|
from axolotl.monkeypatch.multipack import (
|
||||||
|
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||||
|
patch_for_multipack,
|
||||||
|
)
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
@@ -299,8 +303,15 @@ def load_model(
|
|||||||
shifted-sparse attention does not currently support sample packing."
|
shifted-sparse attention does not currently support sample packing."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Modify all llama derived models in one block
|
if (
|
||||||
if cfg.is_llama_derived_model:
|
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
|
and cfg.flash_attention
|
||||||
|
and cfg.sample_packing
|
||||||
|
):
|
||||||
|
patch_for_multipack(cfg.model_config_type)
|
||||||
|
elif cfg.is_llama_derived_model:
|
||||||
|
# Modify all llama derived models in one block
|
||||||
|
|
||||||
if cfg.flash_attention:
|
if cfg.flash_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
@@ -354,43 +365,6 @@ def load_model(
|
|||||||
LOG.info("patching mistral with flash attention")
|
LOG.info("patching mistral with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.model_config_type == "mixtral"
|
|
||||||
and cfg.flash_attention
|
|
||||||
and cfg.sample_packing
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.mixtral import (
|
|
||||||
replace_mixtral_attn_with_multipack_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching mixtral with flash attention")
|
|
||||||
mixtral_patch_kwargs = {}
|
|
||||||
if is_deepspeed_zero3_enabled():
|
|
||||||
mixtral_patch_kwargs["for_zero3"] = True
|
|
||||||
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
|
|
||||||
|
|
||||||
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
|
||||||
from axolotl.monkeypatch.falcon import (
|
|
||||||
replace_falcon_attn_with_multipack_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching falcon with flash attention")
|
|
||||||
replace_falcon_attn_with_multipack_flash_attn()
|
|
||||||
|
|
||||||
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
|
|
||||||
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
|
|
||||||
|
|
||||||
LOG.info("patching phi with flash attention")
|
|
||||||
replace_phi_attn_with_multipack_flash_attn()
|
|
||||||
|
|
||||||
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
|
||||||
from axolotl.monkeypatch.qwen2 import (
|
|
||||||
replace_qwen2_attn_with_multipack_flash_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching qwen2 with flash attention")
|
|
||||||
replace_qwen2_attn_with_multipack_flash_attn()
|
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||||
|
|
||||||
@@ -501,7 +475,7 @@ def load_model(
|
|||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
"flash_attention_2"
|
"flash_attention_2"
|
||||||
|
|||||||
Reference in New Issue
Block a user