Formatting
This commit is contained in:
@@ -46,12 +46,11 @@ from transformers.utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
from xformers.ops import SwiGLU
|
||||||
|
|
||||||
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
from .configuration_moe_mistral import MixtralConfig
|
from .configuration_moe_mistral import MixtralConfig
|
||||||
|
|
||||||
from xformers.ops import SwiGLU
|
|
||||||
|
|
||||||
if is_flash_attn_2_available():
|
if is_flash_attn_2_available():
|
||||||
from flash_attn import (
|
from flash_attn import (
|
||||||
flash_attn_func,
|
flash_attn_func,
|
||||||
|
|||||||
@@ -375,7 +375,7 @@ def load_model(
|
|||||||
elif model_type == "MixtralForCausalLM":
|
elif model_type == "MixtralForCausalLM":
|
||||||
from axolotl.models.mixtral import (
|
from axolotl.models.mixtral import (
|
||||||
MixtralForCausalLM,
|
MixtralForCausalLM,
|
||||||
replace_mixtral_mlp_with_swiglu
|
replace_mixtral_mlp_with_swiglu,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = MixtralForCausalLM.from_pretrained(
|
model = MixtralForCausalLM.from_pretrained(
|
||||||
@@ -387,7 +387,7 @@ def load_model(
|
|||||||
|
|
||||||
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
|
LOG.info("Mixtral MoE: Replacing experts with SwiGLU")
|
||||||
replace_mixtral_mlp_with_swiglu(model)
|
replace_mixtral_mlp_with_swiglu(model)
|
||||||
|
|
||||||
elif model_type == "MambaLMHeadModel":
|
elif model_type == "MambaLMHeadModel":
|
||||||
# FIXME this is janky at best and hacked together to make it work
|
# FIXME this is janky at best and hacked together to make it work
|
||||||
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||||
|
|||||||
Reference in New Issue
Block a user