also support xformers w/o packing
@@ -556,7 +556,7 @@ class ModelLoader:
         self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name

     def apply_patches(self) -> None:
-        if self.cfg.xformers_attention and self.cfg.sample_packing:
+        if self.cfg.xformers_attention:
            from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2

            patch_xformers_attn_over_fa2()
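The effect of this hunk is that the xformers-over-flash-attention patch is applied whenever xformers_attention is enabled on its own; previously sample_packing also had to be set. A minimal standalone sketch of the new gating, where the stand-in cfg object and the print are illustrative and not the real axolotl ModelLoader:

from types import SimpleNamespace

def apply_patches(cfg) -> None:
    # After this commit only cfg.xformers_attention is checked;
    # cfg.sample_packing no longer has to be enabled as well.
    if cfg.xformers_attention:
        # In axolotl this is where patch_xformers_attn_over_fa2() is
        # imported and called; a print stands in for it here.
        print("would call patch_xformers_attn_over_fa2()")

# xformers without packing now takes the patch path too:
apply_patches(SimpleNamespace(xformers_attention=True, sample_packing=False))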
@@ -771,13 +771,6 @@ class ModelLoader:
                 cross_entropy=self.cfg.flash_attn_cross_entropy,
                 rms_norm=self.cfg.flash_attn_rms_norm,
             )
-        elif self.cfg.xformers_attention:
-            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-                hijack_llama_attention,
-            )
-
-            LOG.info("patching with xformers attention")
-            hijack_llama_attention()
         elif self.cfg.sample_packing:
             from axolotl.monkeypatch.llama_patch_multipack import (
                 hijack_llama_prepare_4d_mask,
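With the xformers patch now handled up front in apply_patches(), the separate llama-specific hijack_llama_attention branch in this later dispatch becomes redundant and is removed, leaving the flash-attention and sample-packing branches. A rough sketch of the resulting order; the flash_attention flag name and the elided bodies are assumptions, only the identifiers in the hunks come from the diff:

def patch_llama_attention(cfg) -> None:
    # Sketch of the dispatch after this hunk: only the flash-attention
    # and sample-packing branches remain here.
    if cfg.flash_attention:
        ...  # flash-attn patches (cross_entropy, rms_norm, ...)
    elif cfg.sample_packing:
        ...  # hijack_llama_prepare_4d_mask
    # No xformers branch any more: apply_patches() already ran
    # patch_xformers_attn_over_fa2() when cfg.xformers_attention was set.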
@@ -41,6 +41,7 @@ class WandbConfig(BaseModel):
     use_wandb: bool | None = None
     wandb_name: str | None = None
     wandb_run_id: str | None = None
+    wandb_run_group: str | None = None
     wandb_mode: str | None = None
     wandb_project: str | None = None
     wandb_entity: str | None = None
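For reference, a run-group field of this kind typically maps onto the group argument of wandb.init, which Weights & Biases uses to group related runs in its UI. A hedged sketch of how such a config would be threaded through; WandbConfigSketch and the wiring below are illustrative assumptions, not axolotl's actual code:

import wandb
from pydantic import BaseModel

class WandbConfigSketch(BaseModel):
    # Mirrors the fields shown in the hunk; wandb_run_group is the newly added one.
    use_wandb: bool | None = None
    wandb_name: str | None = None
    wandb_run_id: str | None = None
    wandb_run_group: str | None = None
    wandb_mode: str | None = None
    wandb_project: str | None = None
    wandb_entity: str | None = None

cfg = WandbConfigSketch(
    use_wandb=True,
    wandb_project="demo",
    wandb_run_group="ablation-1",
    wandb_mode="offline",  # offline keeps this runnable without a W&B login
)
if cfg.use_wandb:
    # wandb.init accepts project/entity/name/id/group/mode keyword arguments.
    wandb.init(
        project=cfg.wandb_project,
        entity=cfg.wandb_entity,
        name=cfg.wandb_name,
        id=cfg.wandb_run_id,
        group=cfg.wandb_run_group,
        mode=cfg.wandb_mode,
    )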