Formatting

This commit is contained in:
Casper
2023-12-10 17:15:42 +01:00
parent 23103ac5ac
commit 2ac1a72e4b
2 changed files with 3 additions and 4 deletions

View File

@@ -46,12 +46,11 @@ from transformers.utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from xformers.ops import SwiGLU
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from .configuration_moe_mistral import MixtralConfig from .configuration_moe_mistral import MixtralConfig
from xformers.ops import SwiGLU
if is_flash_attn_2_available(): if is_flash_attn_2_available():
from flash_attn import ( from flash_attn import (
flash_attn_func, flash_attn_func,

View File

@@ -375,7 +375,7 @@ def load_model(
elif model_type == "MixtralForCausalLM": elif model_type == "MixtralForCausalLM":
from axolotl.models.mixtral import ( from axolotl.models.mixtral import (
MixtralForCausalLM, MixtralForCausalLM,
replace_mixtral_mlp_with_swiglu replace_mixtral_mlp_with_swiglu,
) )
model = MixtralForCausalLM.from_pretrained( model = MixtralForCausalLM.from_pretrained(
@@ -387,7 +387,7 @@ def load_model(
LOG.info("Mixtral MoE: Replacing experts with SwiGLU") LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
replace_mixtral_mlp_with_swiglu(model) replace_mixtral_mlp_with_swiglu(model)
elif model_type == "MambaLMHeadModel": elif model_type == "MambaLMHeadModel":
# FIXME this is janky at best and hacked together to make it work # FIXME this is janky at best and hacked together to make it work
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name