Formatting
This commit is contained in:
@@ -46,12 +46,11 @@ from transformers.utils import (
|
||||
logging,
|
||||
replace_return_docstrings,
|
||||
)
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
from .configuration_moe_mistral import MixtralConfig
|
||||
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn import (
|
||||
flash_attn_func,
|
||||
|
||||
@@ -375,7 +375,7 @@ def load_model(
|
||||
elif model_type == "MixtralForCausalLM":
|
||||
from axolotl.models.mixtral import (
|
||||
MixtralForCausalLM,
|
||||
replace_mixtral_mlp_with_swiglu
|
||||
replace_mixtral_mlp_with_swiglu,
|
||||
)
|
||||
|
||||
model = MixtralForCausalLM.from_pretrained(
|
||||
@@ -387,7 +387,7 @@ def load_model(
|
||||
|
||||
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
|
||||
replace_mixtral_mlp_with_swiglu(model)
|
||||
|
||||
|
||||
elif model_type == "MambaLMHeadModel":
|
||||
# FIXME this is janky at best and hacked together to make it work
|
||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||
|
||||
Reference in New Issue
Block a user