swaps to use newer sample packing for mistral (#1773)
* swaps to use newer sample packing for mistral
* fix multipack patch test
* patch the common fa utils
* update for refactor of flash attn unpad
* remove un-needed drop attn mask for mistral
* bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2
* update test
requirements.txt
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.11.1
-transformers==4.42.4
+transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
 tokenizers==0.19.1
 bitsandbytes==0.43.1
 accelerate==0.32.0
src/axolotl/monkeypatch/mistral_attn_hijack_flash.py
@@ -2,6 +2,7 @@
 # pylint: disable=duplicate-code

 import logging
+from functools import partial
 from typing import List, Optional, Tuple, Union

 import torch
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
     )


+def patch_mistral_cross_entropy():
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
+
+    LOG.info("patching with flash_attn.losses.cross_entropy")
+    transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
+        CrossEntropyLoss, inplace_backward=True
+    )
+
+
 @torch.jit.script
 def _make_sliding_window_causal_mask(
     bsz: int,
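Why the partial works: modeling_mistral resolves CrossEntropyLoss as a module-level name each time forward() builds its loss, so rebinding that name swaps in flash-attn's fused kernel with inplace_backward=True already bound. A stand-in sketch of the mechanism, using dummy classes rather than the real torch/flash_attn ones:

from functools import partial

class SlowCE:  # stands in for torch.nn.CrossEntropyLoss
    def __init__(self):
        self.fused = False

class FastCE:  # stands in for flash_attn.losses.cross_entropy.CrossEntropyLoss
    def __init__(self, inplace_backward=False):
        self.fused = True
        self.inplace_backward = inplace_backward

CrossEntropyLoss = SlowCE  # what the modeling module originally binds

def forward():
    # mistral's forward() does the same late, module-level lookup
    return CrossEntropyLoss()

CrossEntropyLoss = partial(FastCE, inplace_backward=True)  # the patch
loss_fct = forward()
assert loss_fct.fused and loss_fct.inplace_backward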
src/axolotl/monkeypatch/multipack.py
@@ -11,6 +11,7 @@ from axolotl.monkeypatch.utils import get_unpad_data

 SUPPORTED_MULTIPACK_MODEL_TYPES = [
     "llama",
+    "mistral",
     "mixtral",
     "qwen2",
     "qwen2_moe",
@@ -25,6 +26,19 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [


 def patch_for_multipack(model_type, model_name=None):
+    if model_type == "gemmoe":
+        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
+    elif model_type == "deepseek_v2":
+        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
+    elif hasattr(transformers, "modeling_flash_attention_utils"):
+        transformers.modeling_flash_attention_utils._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
+        if model_type == "mixtral" and is_deepspeed_zero3_enabled():
+            patch_mixtral_moe_forward_zero3()
+        return
+
+    # retain for legacy
     if model_type == "mixtral":
         transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
@@ -32,9 +46,15 @@ def patch_for_multipack(model_type, model_name=None):
         if is_deepspeed_zero3_enabled():
             patch_mixtral_moe_forward_zero3()
     elif model_type == "llama":
-        transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
-            get_unpad_data
-        )
+        if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
+            transformers.models.llama.modeling_llama._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
+    elif model_type == "mistral":
+        if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
+            transformers.models.mistral.modeling_mistral._get_unpad_data = (  # pylint: disable=protected-access
+                get_unpad_data
+            )
     elif model_type == "qwen2":
         transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
@@ -63,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None):
         transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
-    elif model_type == "gemmoe":
-        patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
-    elif model_type == "jamba":
-        patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
-    elif model_type == "deepseek_v2":
-        patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")


 def patch_remote(model_name, config_name, modeling_name):
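What the swapped-in helper does: axolotl's packed collator stores a per-sample index (1, 2, 3, ...) in attention_mask rather than a plain 0/1 mask, so the unpad helper that upstream FA2 calls must emit one cu_seqlens boundary per packed sample, not per row. A minimal sketch of that computation, not axolotl's exact implementation (which lives in axolotl.monkeypatch.utils):

import torch

def get_unpad_data_sketch(attention_mask: torch.Tensor):
    # attention_mask carries sample ids per token, e.g. [[1, 1, 2, 2, 2, 0]]
    seqlens = []
    for row in attention_mask:
        ids = row[row > 0]
        # one length per packed sample, preserving packing order
        _, counts = torch.unique_consecutive(ids, return_counts=True)
        seqlens.extend(counts.tolist())
    seqlens_t = torch.tensor(seqlens, dtype=torch.int32)
    # flat indices of real (non-padding) tokens, as flash-attn varlen expects
    indices = torch.nonzero(attention_mask.flatten() > 0).flatten()
    cu_seqlens = torch.nn.functional.pad(seqlens_t.cumsum(0), (1, 0)).to(torch.int32)
    max_seqlen_in_batch = int(seqlens_t.max()) if seqlens else 0
    return indices, cu_seqlens, max_seqlen_in_batch

# [[1, 1, 2, 2, 2, 0]] -> cu_seqlens [0, 2, 5], max_seqlen_in_batch 3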
src/axolotl/monkeypatch/unsloth_.py
@@ -99,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
     return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv


-def integrate_cross_entropy_loss_patch():
-    forward = get_forward_code()
-    LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
-    forward, _ = detab_code(forward)
-    assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
+def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
+    if model_type == "llama":
+        forward = get_forward_code()
+        LlamaForCausalLM._original_forward = forward  # pylint: disable=protected-access
+        forward, _ = detab_code(forward)
+        assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"

-    forward = forward.replace(
-        "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
-    )
-    forward = forward.replace(
-        "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
-        "",
-    )
-    forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
-    forward = forward.replace(
-        "def forward(",
-        "def fast_cross_entropy_loss_forward(",
-        1,
-    )
+        forward = forward.replace(
+            "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
+        )
+        forward = forward.replace(
+            "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
+            "",
+        )
+        forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
+        forward = forward.replace(
+            "def forward(",
+            "def fast_cross_entropy_loss_forward(",
+            1,
+        )

-    # load imports necessary
-    import transformers.models.llama.modeling_llama
+        # load imports necessary
+        import transformers.models.llama.modeling_llama

-    items_to_import = []
-    for item in dir(transformers.models.llama.modeling_llama):
-        if item in forward:
-            items_to_import.append(item)
+        items_to_import = []
+        for item in dir(transformers.models.llama.modeling_llama):
+            if item in forward:
+                items_to_import.append(item)

-    exec(  # pylint: disable=exec-used  # nosec B102
-        "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
-        globals(),
-    )
+        exec(  # pylint: disable=exec-used  # nosec B102
+            "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
+            globals(),
+        )

-    exec(  # pylint: disable=exec-used  # nosec B102
-        "from transformers.models.llama.modeling_llama import ("
-        + ", ".join(x for x in items_to_import)
-        + ")",
-        globals(),
-    )
-    exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
-    LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
-    LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable  # noqa: F821
+        exec(  # pylint: disable=exec-used  # nosec B102
+            "from transformers.models.llama.modeling_llama import ("
+            + ", ".join(x for x in items_to_import)
+            + ")",
+            globals(),
+        )
+        exec(forward, globals())  # pylint: disable=exec-used  # nosec B102
+        LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
+        LlamaForCausalLM.forward = fast_cross_entropy_loss_forward  # pylint: disable=undefined-variable  # noqa: F821
+    else:
+        raise ValueError("Unsupported model type")


 def detab_code(code: str) -> Tuple[str, str]:
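For readers new to the pattern above: the patch pulls the source of the model's forward with get_forward_code(), rewrites it as text, exec-utes the result into globals(), and rebinds LlamaForCausalLM.forward. A toy sketch of the same source-rewrite technique (illustrative names, not axolotl's):

import inspect

def original(x):
    return x + 1

src = inspect.getsource(original)
src = src.replace("def original(", "def patched(", 1)  # rename, like "def forward(" above
src = src.replace("x + 1", "x + 2")  # swap the body, like ORIGINAL_CEL_CODE -> PATCHED_CEL_CODE
namespace = {}
exec(src, namespace)  # nosec B102 - same pattern as the patch above
assert namespace["patched"](1) == 3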
src/axolotl/utils/models.py
@@ -367,7 +367,7 @@ def load_model(
             integrate_cross_entropy_loss_patch,
         )

-        integrate_cross_entropy_loss_patch()
+        integrate_cross_entropy_loss_patch(model_type="llama")

     if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
         from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -424,7 +424,7 @@ def load_model(
     if cfg.unsloth_cross_entropy_loss:
         from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

-        integrate_cross_entropy_loss_patch()
+        integrate_cross_entropy_loss_patch(model_type="llama")

     if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
         from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -432,23 +432,12 @@ def load_model(
             patch_self_attn_lora()

     # Modify mistral derived models
-    if (
-        cfg.model_config_type == "mistral"
-        and cfg.flash_attention
-        and cfg.sample_packing
-    ):
+    if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
-            replace_mistral_attn_with_flash_attn,
+            patch_mistral_cross_entropy,
         )

-        LOG.info("patching mistral with flash attention")
-        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
-
-    if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
-        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
-
-        LOG.info("patching _expand_mask")
-        hijack_expand_mask()
+        patch_mistral_cross_entropy()

     model_kwargs: Dict[str, Any] = {}
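Note the behavior change here: the mistral flash-attention hijack (and the llama _expand_mask hijack) are gone; sample packing now flows through the patched upstream FA2 utils, and only the fused cross-entropy patch remains, gated by cfg.flash_attn_cross_entropy_loss. A hedged example of a config exercising the new path, built with DictDefault the way the tests do (the base_model value is illustrative):

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "base_model": "mistralai/Mistral-7B-v0.1",  # illustrative
        "flash_attention": True,
        "sample_packing": True,  # now served by the patched upstream FA2 utils
        "flash_attn_cross_entropy_loss": True,  # triggers patch_mistral_cross_entropy
    }
)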
src/axolotl/utils/trainer.py
@@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     max_input_len = np.max(get_dataset_lengths(train_dataset))
     LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)

-    if (
-        cfg.is_mistral_derived_model and cfg.flash_attention
-    ) or cfg.model_config_type == "mamba":
+    if cfg.model_config_type == "mamba":
         LOG.info("dropping attention_mask column")
         train_dataset = train_dataset.remove_columns("attention_mask")
         if eval_dataset:
tests/e2e/patched/test_model_patches.py
@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu

 import unittest

+import transformers
+
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         tokenizer = load_tokenizer(cfg)
-        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+        load_model(cfg, tokenizer, inference=cli_args.inference)

         assert (
-            "axolotl.monkeypatch.mistral_attn_hijack_flash"
-            in model.model.layers[0].self_attn.forward.__module__
+            "torch.jit"
+            in transformers.modeling_flash_attention_utils._get_unpad_data.__module__  # pylint: disable=protected-access
         )
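The rewritten assertion no longer inspects a hijacked attention layer; it checks that transformers' shared _get_unpad_data now resolves to axolotl's torch.jit-scripted replacement by reading the function's __module__. A minimal sketch of the same verification idea, assuming a transformers build that ships modeling_flash_attention_utils (stand-in function, run as a script):

import transformers.modeling_flash_attention_utils as fa_utils

def my_get_unpad_data(attention_mask):  # illustrative stand-in, never called here
    raise NotImplementedError

saved = fa_utils._get_unpad_data  # keep a handle so the patch can be undone
fa_utils._get_unpad_data = my_get_unpad_data
# __module__ reveals where the active implementation was defined
assert fa_utils._get_unpad_data.__module__ == __name__
fa_utils._get_unpad_data = saved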