diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 6aa4dd162..9feaccc55 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -556,7 +556,7 @@ class ModelLoader:
         self.auto_model_loader = AutoModelForCausalLM  # pylint: disable=invalid-name
 
     def apply_patches(self) -> None:
-        if self.cfg.xformers_attention and self.cfg.sample_packing:
+        if self.cfg.xformers_attention:
             from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
 
             patch_xformers_attn_over_fa2()
@@ -771,13 +771,6 @@ class ModelLoader:
                 cross_entropy=self.cfg.flash_attn_cross_entropy,
                 rms_norm=self.cfg.flash_attn_rms_norm,
             )
-        elif self.cfg.xformers_attention:
-            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
-                hijack_llama_attention,
-            )
-
-            LOG.info("patching with xformers attention")
-            hijack_llama_attention()
         elif self.cfg.sample_packing:
             from axolotl.monkeypatch.llama_patch_multipack import (
                 hijack_llama_prepare_4d_mask,
diff --git a/src/axolotl/utils/schemas/integrations.py b/src/axolotl/utils/schemas/integrations.py
index 9d8f9c190..e912d5d90 100644
--- a/src/axolotl/utils/schemas/integrations.py
+++ b/src/axolotl/utils/schemas/integrations.py
@@ -41,6 +41,7 @@ class WandbConfig(BaseModel):
     use_wandb: bool | None = None
     wandb_name: str | None = None
     wandb_run_id: str | None = None
+    wandb_run_group: str | None = None
     wandb_mode: str | None = None
     wandb_project: str | None = None
     wandb_entity: str | None = None
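
For reference, a minimal sketch of how the new `wandb_run_group` field could be exercised through the schema. Only `WandbConfig`, its import path, and its field names come from the diff above; the field values are hypothetical.

```python
# Hypothetical usage sketch: validate a config carrying the new
# wandb_run_group field. WandbConfig is a pydantic BaseModel, so all
# fields are optional keyword arguments defaulting to None.
from axolotl.utils.schemas.integrations import WandbConfig

cfg = WandbConfig(
    use_wandb=True,
    wandb_project="axolotl-experiments",  # hypothetical project name
    wandb_run_group="packing-ablation",   # field added by this diff
)
print(cfg.wandb_run_group)  # -> "packing-ablation"
```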