Formatting

This commit is contained in:
Casper
2023-12-10 17:15:42 +01:00
parent 23103ac5ac
commit 2ac1a72e4b
2 changed files with 3 additions and 4 deletions

View File

@@ -46,12 +46,11 @@ from transformers.utils import (
logging,
replace_return_docstrings,
)
from xformers.ops import SwiGLU
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from .configuration_moe_mistral import MixtralConfig
from xformers.ops import SwiGLU
if is_flash_attn_2_available():
from flash_attn import (
flash_attn_func,

View File

@@ -375,7 +375,7 @@ def load_model(
elif model_type == "MixtralForCausalLM":
from axolotl.models.mixtral import (
MixtralForCausalLM,
replace_mixtral_mlp_with_swiglu
replace_mixtral_mlp_with_swiglu,
)
model = MixtralForCausalLM.from_pretrained(
@@ -387,7 +387,7 @@ def load_model(
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
replace_mixtral_mlp_with_swiglu(model)
elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name