fix: liger swiglu for llama4 (#2504)

* fix: liger swiglu for llama4

* feat: add liger to deepseek v3

* fix: unpack not found

* fix: spelling

* fix: comment out deepseek v3

* fix: retest deepseek

* fix: map glu

* fix: patch model forward

* chore: add temp code to save

* fix: remove deepseek to move into separate PR
This commit is contained in:
NanoCode012
2025-04-09 13:53:17 +07:00
committed by GitHub
parent 630e40dd13
commit f85861a0b2
2 changed files with 15 additions and 3 deletions

View File

@@ -185,5 +185,7 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type in ["deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
else:
logging.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)

View File

@@ -3,6 +3,7 @@ Liger FLCE for llama4
"""
import sys
from copy import deepcopy
from typing import List, Optional, Tuple, Union
import torch
@@ -158,7 +159,16 @@ def apply_liger_kernel_to_llama4(
if rms_norm:
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
if glu_activation:
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
return LigerSwiGLUMLP(config, **kwargs)
modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_llama4.nn.LayerNorm = LigerLayerNorm