swaps to use newer sample packing for mistral (#1773)

* swaps to use newer sample packing for mistral

* fix multipack patch test

* patch the common fa utils

* update for refactor of flash attn unpad

* remove un-needed drop attn mask for mistral

* bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2

* update test
This commit is contained in:
Wing Lian
2024-07-23 01:41:11 -04:00
committed by GitHub
parent 985819d89b
commit 87455e7f32
7 changed files with 85 additions and 69 deletions

View File

@@ -367,7 +367,7 @@ def load_model(
integrate_cross_entropy_loss_patch,
)
integrate_cross_entropy_loss_patch()
integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -424,7 +424,7 @@ def load_model(
if cfg.unsloth_cross_entropy_loss:
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
integrate_cross_entropy_loss_patch()
integrate_cross_entropy_loss_patch(model_type="llama")
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
@@ -432,23 +432,12 @@ def load_model(
patch_self_attn_lora()
# Modify mistral derived models
if (
cfg.model_config_type == "mistral"
and cfg.flash_attention
and cfg.sample_packing
):
if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
replace_mistral_attn_with_flash_attn,
patch_mistral_cross_entropy,
)
LOG.info("patching mistral with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
patch_mistral_cross_entropy()
model_kwargs: Dict[str, Any] = {}