swaps to use newer sample packing for mistral (#1773)
* swaps to use newer sample packing for mistral * fix multipack patch test * patch the common fa utils * update for refactor of flash attn unpad * remove un-needed drop attn mask for mistral * bump transformers to main to pick up latest mistral fix for 12b and refactor of fa2 * update test
This commit is contained in:
@@ -367,7 +367,7 @@ def load_model(
|
||||
integrate_cross_entropy_loss_patch,
|
||||
)
|
||||
|
||||
integrate_cross_entropy_loss_patch()
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
|
||||
@@ -424,7 +424,7 @@ def load_model(
|
||||
if cfg.unsloth_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch
|
||||
|
||||
integrate_cross_entropy_loss_patch()
|
||||
integrate_cross_entropy_loss_patch(model_type="llama")
|
||||
|
||||
if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora
|
||||
@@ -432,23 +432,12 @@ def load_model(
|
||||
patch_self_attn_lora()
|
||||
|
||||
# Modify mistral derived models
|
||||
if (
|
||||
cfg.model_config_type == "mistral"
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
if cfg.model_config_type == "mistral" and cfg.flash_attn_cross_entropy_loss:
|
||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||
replace_mistral_attn_with_flash_attn,
|
||||
patch_mistral_cross_entropy,
|
||||
)
|
||||
|
||||
LOG.info("patching mistral with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
LOG.info("patching _expand_mask")
|
||||
hijack_expand_mask()
|
||||
patch_mistral_cross_entropy()
|
||||
|
||||
model_kwargs: Dict[str, Any] = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user