swaps to use newer sample packing for mistral (#1773)

* swaps to use newer sample packing for mistral

* fix multipack patch test

* patch the common fa utils

* update for refactor of flash attn unpad

* remove the unneeded attention-mask drop for mistral

* bump transformers to main to pick up the latest mistral fix for 12b and the refactor of fa2

* update test
This commit is contained in:
Wing Lian
2024-07-23 01:41:11 -04:00
committed by GitHub
parent 985819d89b
commit 87455e7f32
7 changed files with 85 additions and 69 deletions

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/ --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2 packaging==23.2
peft==0.11.1 peft==0.11.1
transformers==4.42.4 transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
tokenizers==0.19.1 tokenizers==0.19.1
bitsandbytes==0.43.1 bitsandbytes==0.43.1
accelerate==0.32.0 accelerate==0.32.0

View File

@@ -2,6 +2,7 @@
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import logging import logging
from functools import partial
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import torch import torch
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
) )
def patch_mistral_cross_entropy():
    """
    Replace the ``CrossEntropyLoss`` name in the Mistral modeling module with
    flash-attn's ``CrossEntropyLoss`` pre-bound to ``inplace_backward=True``.
    """
    # function-scope import: flash_attn is only imported when this patch runs
    from flash_attn.losses.cross_entropy import (
        CrossEntropyLoss as FlashCrossEntropyLoss,
    )

    LOG.info("patching with flash_attn.losses.cross_entropy")

    mistral_modeling = transformers.models.mistral.modeling_mistral
    mistral_modeling.CrossEntropyLoss = partial(
        FlashCrossEntropyLoss, inplace_backward=True
    )
@torch.jit.script @torch.jit.script
def _make_sliding_window_causal_mask( def _make_sliding_window_causal_mask(
bsz: int, bsz: int,

View File

@@ -11,6 +11,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama", "llama",
"mistral",
"mixtral", "mixtral",
"qwen2", "qwen2",
"qwen2_moe", "qwen2_moe",
@@ -25,6 +26,19 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
def patch_for_multipack(model_type, model_name=None): def patch_for_multipack(model_type, model_name=None):
if model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
elif hasattr(transformers, "modeling_flash_attention_utils"):
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
return
# retain for legacy
if model_type == "mixtral": if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
@@ -32,9 +46,15 @@ def patch_for_multipack(model_type, model_name=None):
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3() patch_mixtral_moe_forward_zero3()
elif model_type == "llama": elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
get_unpad_data transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
) get_unpad_data
)
elif model_type == "mistral":
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2": elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
@@ -63,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data get_unpad_data
) )
elif model_type == "gemmoe":
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
elif model_type == "jamba":
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
elif model_type == "deepseek_v2":
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
def patch_remote(model_name, config_name, modeling_name): def patch_remote(model_name, config_name, modeling_name):

View File

@@ -99,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
def integrate_cross_entropy_loss_patch(): def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
forward = get_forward_code() if model_type == "llama":
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access forward = get_forward_code()
forward, _ = detab_code(forward) LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found" forward, _ = detab_code(forward)
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
forward = forward.replace( forward = forward.replace(
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", "" "@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
) )
forward = forward.replace( forward = forward.replace(
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)", "@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
"", "",
) )
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE) forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
forward = forward.replace( forward = forward.replace(
"def forward(", "def forward(",
"def fast_cross_entropy_loss_forward(", "def fast_cross_entropy_loss_forward(",
1, 1,
) )
# load imports necessary # load imports necessary
import transformers.models.llama.modeling_llama import transformers.models.llama.modeling_llama
items_to_import = [] items_to_import = []
for item in dir(transformers.models.llama.modeling_llama): for item in dir(transformers.models.llama.modeling_llama):
if item in forward: if item in forward:
items_to_import.append(item) items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss", "from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
globals(), globals(),
) )
exec( # pylint: disable=exec-used # nosec B102 exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import (" "from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import) + ", ".join(x for x in items_to_import)
+ ")", + ")",
globals(), globals(),
) )
exec(forward, globals()) # pylint: disable=exec-used # nosec B102 exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True) LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821 LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
else:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]: def detab_code(code: str) -> Tuple[str, str]:

View File

@@ -367,7 +367,7 @@ def load_model(
integrate_cross_entropy_loss_patch, integrate_cross_entropy_loss_patch,
) )
integrate_cross_entropy_loss_patch() integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -424,7 +424,7 @@ def load_model(
if cfg.unsloth_cross_entropy_loss: if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch() integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o: if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -432,23 +432,12 @@ def load_model(
patch_self_attn_lora() patch_self_attn_lora()
# Modify mistral derived models # Modify mistral derived models
if ( if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
cfg.model_config_type == "mistral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mistral_attn_hijack_flash import ( from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn, patch_mistral_cross_entropy,
) )
LOG.info("patching mistral with flash attention") patch_mistral_cross_entropy()
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
model_kwargs: Dict[str, Any] = {} model_kwargs: Dict[str, Any] = {}

View File

@@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
max_input_len = np.max(get_dataset_lengths(train_dataset)) max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if ( if cfg.model_config_type == "mamba":
cfg.is_mistral_derived_model and cfg.flash_attention
) or cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column") LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:

View File

@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu
import unittest import unittest
import transformers
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
normalize_config(cfg) normalize_config(cfg)
cli_args = TrainerCliArgs() cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference) load_model(cfg, tokenizer, inference=cli_args.inference)
assert ( assert (
"axolotl.monkeypatch.mistral_attn_hijack_flash" "torch.jit"
in model.model.layers[0].self_attn.forward.__module__ in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
) )