Implement Mistral SwiGLU
This commit is contained in:
@@ -19,19 +19,29 @@ from transformers.models.mistral.modeling_mistral import (
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
MistralMLP
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
|
||||
from axolotl.monkeypatch.flash_modules import (
|
||||
flashattn_forward,
|
||||
replace_cross_entropy,
|
||||
replace_rms_norm
|
||||
)
|
||||
from axolotl.monkeypatch.fused_modules import FusedMLP
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
def replace_mistral_mlp_with_swiglu(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, MistralMLP):
|
||||
mlp = FusedMLP(
|
||||
module.config, module.gate_proj, module.up_proj, module.down_proj
|
||||
)
|
||||
set_module_name(model, name, mlp)
|
||||
|
||||
|
||||
def replace_mistral_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
|
||||
@@ -278,6 +278,15 @@ def load_model(
|
||||
if cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("patching with fused QKV")
|
||||
replace_llama_qkv_with_fused(model)
|
||||
elif cfg.is_mistral_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
if cfg.flash_attention and not inference:
|
||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||
replace_mistral_mlp_with_swiglu,
|
||||
)
|
||||
|
||||
if cfg.flash_attn_fuse_mlp:
|
||||
LOG.info("patching with SwiGLU")
|
||||
replace_mistral_mlp_with_swiglu(model)
|
||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||
# This is a WIP, still an issue with the backward pass
|
||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||
|
||||
Reference in New Issue
Block a user