Feat: Allow usage of native Mistral FA when no sample_packing (#669)

* Allow usage of native Mistral FA when no sample_packing

* fix: do not apply custom patch when sample_packing is off

* chore: lint

* chore: pin transformers to v4.35.0.dev0

* fix: split sample_packing into a separate test
Author: NanoCode012
Date: 2023-10-04 20:40:47 +09:00
Committed by: GitHub
Commit: 697c50d408 (parent 90e0d673f7)
4 changed files with 125 additions and 95 deletions


@@ -149,7 +149,7 @@ def load_model(
         # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
-    if cfg.is_mistral_derived_model and cfg.flash_attention:
+    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
         )
@@ -200,7 +200,11 @@ def load_model(
         )
     # sample packing uses custom FA2 patch
     if cfg.flash_attention and not cfg.sample_packing:
-        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+        if (
+            cfg.is_llama_derived_model
+            or cfg.is_falcon_derived_model
+            or cfg.is_mistral_derived_model
+        ):
             model_kwargs["use_flash_attention_2"] = True
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq: