Formatting

This commit is contained in:
Casper
2023-12-10 17:15:42 +01:00
parent 23103ac5ac
commit 2ac1a72e4b
2 changed files with 3 additions and 4 deletions

View File

@@ -46,12 +46,11 @@ from transformers.utils import (
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
from xformers.ops import SwiGLU
from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name from ...monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
from .configuration_moe_mistral import MixtralConfig from .configuration_moe_mistral import MixtralConfig
from xformers.ops import SwiGLU
if is_flash_attn_2_available(): if is_flash_attn_2_available():
from flash_attn import ( from flash_attn import (
flash_attn_func, flash_attn_func,

View File

@@ -375,7 +375,7 @@ def load_model(
elif model_type == "MixtralForCausalLM": elif model_type == "MixtralForCausalLM":
from axolotl.models.mixtral import ( from axolotl.models.mixtral import (
MixtralForCausalLM, MixtralForCausalLM,
replace_mixtral_mlp_with_swiglu replace_mixtral_mlp_with_swiglu,
) )
model = MixtralForCausalLM.from_pretrained( model = MixtralForCausalLM.from_pretrained(