swaps to use newer sample packing for mistral (#1773)
* swaps to use newer sample packing for mistral
* fix multipack patch test
* patch the common fa utils
* update for refactor of flash attn unpad
* remove un-needed drop attn mask for mistral
* bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2
* update test
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.11.1
|
peft==0.11.1
|
||||||
transformers==4.42.4
|
transformers @ git+https://github.com/huggingface/transformers.git@0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf
|
||||||
tokenizers==0.19.1
|
tokenizers==0.19.1
|
||||||
bitsandbytes==0.43.1
|
bitsandbytes==0.43.1
|
||||||
accelerate==0.32.0
|
accelerate==0.32.0
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from functools import partial
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -45,6 +46,15 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_mistral_cross_entropy():
|
||||||
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
|
transformers.models.mistral.modeling_mistral.CrossEntropyLoss = partial(
|
||||||
|
CrossEntropyLoss, inplace_backward=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
def _make_sliding_window_causal_mask(
|
def _make_sliding_window_causal_mask(
|
||||||
bsz: int,
|
bsz: int,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
|
|||||||
|
|
||||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||||
"llama",
|
"llama",
|
||||||
|
"mistral",
|
||||||
"mixtral",
|
"mixtral",
|
||||||
"qwen2",
|
"qwen2",
|
||||||
"qwen2_moe",
|
"qwen2_moe",
|
||||||
@@ -25,6 +26,19 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type, model_name=None):
|
def patch_for_multipack(model_type, model_name=None):
|
||||||
|
if model_type == "gemmoe":
|
||||||
|
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
||||||
|
elif model_type == "deepseek_v2":
|
||||||
|
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
||||||
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
if model_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||||
|
patch_mixtral_moe_forward_zero3()
|
||||||
|
return
|
||||||
|
|
||||||
|
# retain for legacy
|
||||||
if model_type == "mixtral":
|
if model_type == "mixtral":
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
@@ -32,9 +46,15 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
patch_mixtral_moe_forward_zero3()
|
patch_mixtral_moe_forward_zero3()
|
||||||
elif model_type == "llama":
|
elif model_type == "llama":
|
||||||
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
if hasattr(transformers.models.llama.modeling_llama, "_get_unpad_data"):
|
||||||
get_unpad_data
|
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
)
|
get_unpad_data
|
||||||
|
)
|
||||||
|
elif model_type == "mistral":
|
||||||
|
if hasattr(transformers.models.mistral.modeling_mistral, "_get_unpad_data"):
|
||||||
|
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
elif model_type == "qwen2":
|
elif model_type == "qwen2":
|
||||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
@@ -63,12 +83,6 @@ def patch_for_multipack(model_type, model_name=None):
|
|||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
elif model_type == "gemmoe":
|
|
||||||
patch_remote(model_name, ".configuration_gemmoe", ".modeling_gemmoe")
|
|
||||||
elif model_type == "jamba":
|
|
||||||
patch_remote(model_name, ".configuration_jamba", ".modeling_jamba")
|
|
||||||
elif model_type == "deepseek_v2":
|
|
||||||
patch_remote(model_name, ".configuration_deepseek", ".modeling_deepseek")
|
|
||||||
|
|
||||||
|
|
||||||
def patch_remote(model_name, config_name, modeling_name):
|
def patch_remote(model_name, config_name, modeling_name):
|
||||||
|
|||||||
@@ -99,48 +99,51 @@ def check_self_attn_is_patchable() -> bool:
|
|||||||
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
return ORIGINAL_QKV_CODE in qkv and ORIGINAL_O_CODE in qkv
|
||||||
|
|
||||||
|
|
||||||
def integrate_cross_entropy_loss_patch():
|
def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||||
forward = get_forward_code()
|
if model_type == "llama":
|
||||||
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
forward = get_forward_code()
|
||||||
forward, _ = detab_code(forward)
|
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
|
||||||
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
forward, _ = detab_code(forward)
|
||||||
|
assert ORIGINAL_CEL_CODE in forward, "Original forward code not found"
|
||||||
|
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
"@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)", ""
|
||||||
)
|
)
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
"@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)",
|
||||||
"",
|
"",
|
||||||
)
|
)
|
||||||
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
forward = forward.replace(ORIGINAL_CEL_CODE, PATCHED_CEL_CODE)
|
||||||
forward = forward.replace(
|
forward = forward.replace(
|
||||||
"def forward(",
|
"def forward(",
|
||||||
"def fast_cross_entropy_loss_forward(",
|
"def fast_cross_entropy_loss_forward(",
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# load imports necessary
|
# load imports necessary
|
||||||
import transformers.models.llama.modeling_llama
|
import transformers.models.llama.modeling_llama
|
||||||
|
|
||||||
items_to_import = []
|
items_to_import = []
|
||||||
for item in dir(transformers.models.llama.modeling_llama):
|
for item in dir(transformers.models.llama.modeling_llama):
|
||||||
if item in forward:
|
if item in forward:
|
||||||
items_to_import.append(item)
|
items_to_import.append(item)
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
"from unsloth.kernels.cross_entropy_loss import fast_cross_entropy_loss",
|
||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
|
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
"from transformers.models.llama.modeling_llama import ("
|
"from transformers.models.llama.modeling_llama import ("
|
||||||
+ ", ".join(x for x in items_to_import)
|
+ ", ".join(x for x in items_to_import)
|
||||||
+ ")",
|
+ ")",
|
||||||
globals(),
|
globals(),
|
||||||
)
|
)
|
||||||
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||||
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
LOG.info("patching unsloth fast_cross_entropy_loss", main_process_only=True)
|
||||||
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
LlamaForCausalLM.forward = fast_cross_entropy_loss_forward # pylint: disable=undefined-variable # noqa: F821
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported model type")
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
|
|||||||
@@ -367,7 +367,7 @@ def load_model(
|
|||||||
integrate_cross_entropy_loss_patch,
|
integrate_cross_entropy_loss_patch,
|
||||||
)
|
)
|
||||||
|
|
||||||
integrate_cross_entropy_loss_patch()
|
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
|
|
||||||
@@ -424,7 +424,7 @@ def load_model(
|
|||||||
if cfg.unsloth_cross_entropy_loss:
|
if cfg.unsloth_cross_entropy_loss:
|
||||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||||
|
|
||||||
integrate_cross_entropy_loss_patch()
|
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||||
|
|
||||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||||
@@ -432,23 +432,12 @@ def load_model(
|
|||||||
patch_self_attn_lora()
|
patch_self_attn_lora()
|
||||||
|
|
||||||
# Modify mistral derived models
|
# Modify mistral derived models
|
||||||
if (
|
if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
|
||||||
cfg.model_config_type == "mistral"
|
|
||||||
and cfg.flash_attention
|
|
||||||
and cfg.sample_packing
|
|
||||||
):
|
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
replace_mistral_attn_with_flash_attn,
|
patch_mistral_cross_entropy,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching mistral with flash attention")
|
patch_mistral_cross_entropy()
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
|
||||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
|
||||||
|
|
||||||
LOG.info("patching _expand_mask")
|
|
||||||
hijack_expand_mask()
|
|
||||||
|
|
||||||
model_kwargs: Dict[str, Any] = {}
|
model_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
|||||||
@@ -189,9 +189,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
|
|
||||||
if (
|
if cfg.model_config_type == "mamba":
|
||||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
|
||||||
) or cfg.model_config_type == "mamba":
|
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ E2E smoke tests to check that the monkeypatches are in place for certain configu
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -87,9 +89,9 @@ class TestModelPatches(unittest.TestCase):
|
|||||||
normalize_config(cfg)
|
normalize_config(cfg)
|
||||||
cli_args = TrainerCliArgs()
|
cli_args = TrainerCliArgs()
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
|
load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
"axolotl.monkeypatch.mistral_attn_hijack_flash"
|
"torch.jit"
|
||||||
in model.model.layers[0].self_attn.forward.__module__
|
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user