Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
5e50d1e8f0 batch flattening with xformers too 2025-05-08 18:23:25 -04:00
Wing Lian
7fb01f0461 also support xformers w/o packing 2025-05-08 15:22:48 -04:00
3 changed files with 10 additions and 10 deletions

View File

@@ -556,7 +556,7 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.xformers_attention and self.cfg.sample_packing:
if self.cfg.xformers_attention:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
@@ -771,13 +771,6 @@ class ModelLoader:
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
elif self.cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif self.cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,

View File

@@ -475,8 +475,14 @@ class AxolotlInputConfig(
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto:
raise ValueError("batch_flattening requires flash attention")
if (
not data.get("flash_attention")
and not data.get("xformers_attention")
and not batch_flattening_auto
):
raise ValueError(
"batch_flattening requires flash attention or xformers"
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:

View File

@@ -41,6 +41,7 @@ class WandbConfig(BaseModel):
use_wandb: bool | None = None
wandb_name: str | None = None
wandb_run_id: str | None = None
wandb_run_group: str | None = None
wandb_mode: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None