Compare commits

2 Commits: version-de...xformers-w
| Author | SHA1 | Date |
|---|---|---|
| | 5e50d1e8f0 | |
| | 7fb01f0461 | |
```diff
@@ -556,7 +556,7 @@ class ModelLoader:
         self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name

     def apply_patches(self) -> None:
-        if self.cfg.xformers_attention and self.cfg.sample_packing:
+        if self.cfg.xformers_attention:
             from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2

             patch_xformers_attn_over_fa2()
```
```diff
@@ -771,13 +771,6 @@ class ModelLoader:
                 cross_entropy=self.cfg.flash_attn_cross_entropy,
                 rms_norm=self.cfg.flash_attn_rms_norm,
             )
-        elif self.cfg.xformers_attention:
-            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-                hijack_llama_attention,
-            )
-
-            LOG.info("patching with xformers attention")
-            hijack_llama_attention()
         elif self.cfg.sample_packing:
             from axolotl.monkeypatch.llama_patch_multipack import (
                 hijack_llama_prepare_4d_mask,
```
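Taken together, these two hunks consolidate xformers handling in `ModelLoader.apply_patches`: the `patch_xformers_attn_over_fa2` path now fires whenever `xformers_attention` is set (not only together with `sample_packing`), and the old llama-specific `hijack_llama_attention` branch becomes dead code and is removed. A minimal sketch of the resulting gate, using a stand-in config object (the `SimpleNamespace` stub and the print are illustrative, not axolotl's API):

```python
from types import SimpleNamespace

def apply_patches(cfg) -> None:
    # After this change, the xformers-over-FA2 patch is gated only on
    # xformers_attention; sample_packing no longer matters here, and the
    # separate llama-only xformers branch is gone.
    if cfg.xformers_attention:
        # In axolotl this would call:
        #   from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
        #   patch_xformers_attn_over_fa2()
        print("patching xformers attention over flash-attention-2")

# A config that the old `and self.cfg.sample_packing` gate would have skipped:
cfg = SimpleNamespace(xformers_attention=True, sample_packing=False)
apply_patches(cfg)
```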
```diff
@@ -475,8 +475,14 @@ class AxolotlInputConfig(
     def check_batch_flattening_fa(cls, data):
         if data.get("batch_flattening"):
             batch_flattening_auto = data.get("batch_flattening") == "auto"
-            if not data.get("flash_attention") and not batch_flattening_auto:
-                raise ValueError("batch_flattening requires flash attention")
+            if (
+                not data.get("flash_attention")
+                and not data.get("xformers_attention")
+                and not batch_flattening_auto
+            ):
+                raise ValueError(
+                    "batch_flattening requires flash attention or xformers"
+                )
             if data.get("sample_packing") and not batch_flattening_auto:
                 raise ValueError("batch_flattening not compatible with sample_packing")
             if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
```
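The validation hunk widens the `batch_flattening` check so xformers attention satisfies it as well as flash attention. A standalone, runnable sketch of the updated logic with the pydantic plumbing stripped away (the bare `check_batch_flattening` function and the example dicts are illustrative):

```python
def check_batch_flattening(data: dict) -> None:
    # Mirrors the post-change validator: batch_flattening needs flash
    # attention OR xformers attention, unless it is set to "auto".
    if data.get("batch_flattening"):
        batch_flattening_auto = data.get("batch_flattening") == "auto"
        if (
            not data.get("flash_attention")
            and not data.get("xformers_attention")
            and not batch_flattening_auto
        ):
            raise ValueError("batch_flattening requires flash attention or xformers")
        if data.get("sample_packing") and not batch_flattening_auto:
            raise ValueError("batch_flattening not compatible with sample_packing")

# Accepted after this change; rejected before it:
check_batch_flattening({"batch_flattening": True, "xformers_attention": True})

# Still rejected: neither attention backend, and not "auto":
try:
    check_batch_flattening({"batch_flattening": True})
except ValueError as err:
    print(err)
```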
```diff
@@ -41,6 +41,7 @@ class WandbConfig(BaseModel):
     use_wandb: bool | None = None
     wandb_name: str | None = None
     wandb_run_id: str | None = None
+    wandb_run_group: str | None = None
     wandb_mode: str | None = None
     wandb_project: str | None = None
     wandb_entity: str | None = None
```
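The last hunk adds a `wandb_run_group` field to `WandbConfig`. A trimmed sketch of the model with only the fields visible in the diff (the usage values are made up, and whether the field is forwarded to `wandb.init(group=...)` is an assumption not shown here):

```python
from pydantic import BaseModel

class WandbConfig(BaseModel):
    # Only the fields visible in the diff; the real model has more.
    use_wandb: bool | None = None
    wandb_name: str | None = None
    wandb_run_id: str | None = None
    wandb_run_group: str | None = None  # new field added by this change
    wandb_mode: str | None = None
    wandb_project: str | None = None
    wandb_entity: str | None = None

# Hypothetical usage: project and group names are placeholders.
cfg = WandbConfig(wandb_project="my-project", wandb_run_group="ablation-a")
print(cfg.wandb_run_group)
```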