Implement Mistral SwiGLU
This commit is contained in:
@@ -19,19 +19,29 @@ from transformers.models.mistral.modeling_mistral import (
|
|||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
|
MistralMLP
|
||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||||
|
|
||||||
from axolotl.monkeypatch.flash_modules import (
|
from axolotl.monkeypatch.flash_modules import (
|
||||||
flashattn_forward,
|
flashattn_forward,
|
||||||
replace_cross_entropy,
|
replace_cross_entropy,
|
||||||
replace_rms_norm
|
replace_rms_norm
|
||||||
)
|
)
|
||||||
|
from axolotl.monkeypatch.fused_modules import FusedMLP
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||||
|
|
||||||
|
def replace_mistral_mlp_with_swiglu(model):
    """Swap every ``MistralMLP`` submodule of ``model`` for a ``FusedMLP``.

    Each replacement is built from the original module's ``config``,
    ``gate_proj``, ``up_proj`` and ``down_proj`` and installed under the same
    dotted module name via ``set_module_name``. ``model`` is modified in
    place and nothing is returned.
    """
    # Collect the matches before touching the module tree. NOTE(review):
    # ``model`` is presumably a torch ``nn.Module``, whose ``named_modules()``
    # walks the tree lazily; swapping children while that walk is still in
    # progress makes the traversal depend on mutation order.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, MistralMLP)
    ]
    for name, module in targets:
        fused_mlp = FusedMLP(
            module.config, module.gate_proj, module.up_proj, module.down_proj
        )
        set_module_name(model, name, fused_mlp)
|
||||||
|
|
||||||
|
|
||||||
def replace_mistral_attn_with_flash_attn(
|
def replace_mistral_attn_with_flash_attn(
|
||||||
packed: Optional[bool] = False,
|
packed: Optional[bool] = False,
|
||||||
|
|||||||
@@ -278,6 +278,15 @@ def load_model(
|
|||||||
if cfg.flash_attn_fuse_qkv:
|
if cfg.flash_attn_fuse_qkv:
|
||||||
LOG.info("patching with fused QKV")
|
LOG.info("patching with fused QKV")
|
||||||
replace_llama_qkv_with_fused(model)
|
replace_llama_qkv_with_fused(model)
|
||||||
|
elif cfg.is_mistral_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
|
if cfg.flash_attention and not inference:
|
||||||
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
|
replace_mistral_mlp_with_swiglu,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.flash_attn_fuse_mlp:
|
||||||
|
LOG.info("patching with SwiGLU")
|
||||||
|
replace_mistral_mlp_with_swiglu(model)
|
||||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||||
# This is a WIP, still an issue with the backward pass
|
# This is a WIP, still an issue with the backward pass
|
||||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||||
|
|||||||
Reference in New Issue
Block a user