wire up the patch
This commit is contained in:
@@ -556,6 +556,12 @@ class ModelLoader:
|
|||||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||||
|
|
||||||
def apply_patches(self) -> None:
|
def apply_patches(self) -> None:
|
||||||
|
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||||
|
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
|
||||||
|
|
||||||
|
patch_xformers_attn_over_fa2()
|
||||||
|
self.cfg.flash_attention = True
|
||||||
|
|
||||||
if self.cfg.chunked_cross_entropy:
|
if self.cfg.chunked_cross_entropy:
|
||||||
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
|
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user