From 2ac1a72e4b7fc0a68a4a786e2a95ed82f54d1486 Mon Sep 17 00:00:00 2001 From: Casper Date: Sun, 10 Dec 2023 17:15:42 +0100 Subject: [PATCH] Formatting --- src/axolotl/models/mixtral/modeling_moe_mistral.py | 3 +-- src/axolotl/utils/models.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/axolotl/models/mixtral/modeling_moe_mistral.py b/src/axolotl/models/mixtral/modeling_moe_mistral.py index 6e1dc5b24..cc2a81659 100644 --- a/src/axolotl/models/mixtral/modeling_moe_mistral.py +++ b/src/axolotl/models/mixtral/modeling_moe_mistral.py @@ -46,12 +46,11 @@ from transformers.utils import ( logging, replace_return_docstrings, ) +from xformers.ops import SwiGLU from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name from .configuration_moe_mistral import MixtralConfig -from xformers.ops import SwiGLU - if is_flash_attn_2_available(): from flash_attn import ( flash_attn_func, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 84a4f6cf1..399bd9ff8 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -375,7 +375,7 @@ def load_model( elif model_type == "MixtralForCausalLM": from axolotl.models.mixtral import ( MixtralForCausalLM, - replace_mixtral_mlp_with_swiglu + replace_mixtral_mlp_with_swiglu, ) model = MixtralForCausalLM.from_pretrained( @@ -387,7 +387,7 @@ def load_model( LOG.info("Mixtral MoE: Replacing experts with SwiGLU") replace_mixtral_mlp_with_swiglu(model) - + elif model_type == "MambaLMHeadModel": # FIXME this is janky at best and hacked together to make it work MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name