flash_attention + sample packing for stablelm 3b (#671)

* stablelm epoch fa patch

* is causal for fa

* working stablelm fa w packing

* chore: pre-commit linting
This commit is contained in:
Wing Lian
2023-10-05 16:03:43 -04:00
committed by GitHub
parent eb480dfd68
commit 2d60ba3a6e
3 changed files with 429 additions and 1 deletions

View File

@@ -124,6 +124,17 @@ def load_model(
replace_btlm_attn_with_flash_attn(cfg.base_model)
if (
hasattr(model_config, "model_type")
and model_config.model_type == "stablelm_epoch"
):
if cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
replace_stablelm_attn_with_flash_attn,
)
replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (