Feat: Allow usage of native Mistral FA when no sample_packing (#669)
* Allow usage of native Mistral FA when no sample_packing
* fix: do not apply custom patch when sample_packing off
* chore: lint
* chore: pin transformers to v4.35.0.dev0
* fix: split sample_packing into separate test
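In user terms, "native Mistral FA" means that when sample packing is disabled, transformers' built-in FlashAttention-2 support takes over instead of axolotl's monkeypatch. A minimal sketch of the equivalent underlying transformers call (the model name is illustrative, and the snippet assumes flash-attn is installed and transformers is at v4.35.0.dev0 or newer, per the pin in this commit):

# Sketch only (not axolotl code): what the native path amounts to once
# model_kwargs carries use_flash_attention_2=True.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1",        # illustrative Mistral-derived model
    torch_dtype=torch.bfloat16,         # FA2 requires fp16/bf16 weights
    use_flash_attention_2=True,         # native FA2; needs flash-attn installed
)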
@@ -149,7 +149,7 @@ def load_model(
         # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
-    if cfg.is_mistral_derived_model and cfg.flash_attention:
+    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
         )
@@ -200,7 +200,11 @@ def load_model(
         )
     # sample packing uses custom FA2 patch
     if cfg.flash_attention and not cfg.sample_packing:
-        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+        if (
+            cfg.is_llama_derived_model
+            or cfg.is_falcon_derived_model
+            or cfg.is_mistral_derived_model
+        ):
             model_kwargs["use_flash_attention_2"] = True
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
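Read together, the two hunks make the two flash attention paths mutually exclusive on cfg.sample_packing. A condensed sketch of the resulting dispatch (the helper name configure_flash_attention is invented for illustration; cfg is axolotl's parsed config object):

# Sketch condensed from the hunks above; not verbatim load_model code.
def configure_flash_attention(cfg, model_kwargs):
    # With sample packing, Mistral needs axolotl's custom FA2 monkeypatch,
    # which handles packed multi-sequence batches.
    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
            replace_mistral_attn_with_flash_attn,
        )

        replace_mistral_attn_with_flash_attn()

    # Without sample packing, llama-, falcon-, and (now) mistral-derived
    # models defer to transformers' native FA2 integration instead.
    if cfg.flash_attention and not cfg.sample_packing:
        if (
            cfg.is_llama_derived_model
            or cfg.is_falcon_derived_model
            or cfg.is_mistral_derived_model
        ):
            model_kwargs["use_flash_attention_2"] = True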