fix: liger swiglu for llama4 (#2504)
* fix: liger swiglu for llama4
* feat: add liger to deepseek v3
* fix: unpack not found
* fix: spelling
* fix: comment out deepseek v3
* fix: retest deepseek
* fix: map glu
* fix: patch model forward
* chore: add temp code to save
* fix: remove deepseek to move into separate PR
This commit is contained in:
@@ -185,5 +185,7 @@ class LigerPlugin(BasePlugin):
|
|||||||
rms_norm=cfg.liger_rms_norm,
|
rms_norm=cfg.liger_rms_norm,
|
||||||
layer_norm=cfg.liger_layer_norm,
|
layer_norm=cfg.liger_layer_norm,
|
||||||
)
|
)
|
||||||
elif cfg.model_config_type in ["deepseek_v3"]:
|
else:
|
||||||
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
logging.warning(
|
||||||
|
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||||
|
)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ Liger FLCE for llama4
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from copy import deepcopy
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -158,7 +159,16 @@ def apply_liger_kernel_to_llama4(
|
|||||||
if rms_norm:
|
if rms_norm:
|
||||||
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
|
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
|
||||||
if glu_activation:
|
if glu_activation:
|
||||||
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
|
|
||||||
|
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
    """Build a ``LigerSwiGLUMLP``, optionally overriding ``intermediate_size``.

    ``LigerSwiGLUMLP`` is constructed from a config object, so an explicitly
    supplied ``intermediate_size`` is written onto a copy of ``config``
    before the module is created.

    Args:
        config: Model config forwarded to ``LigerSwiGLUMLP``.
        intermediate_size: Optional override for ``config.intermediate_size``.
            Falsy values (``None``, ``0``) leave the config's own value
            in place.
        **kwargs: Extra keyword arguments forwarded to ``LigerSwiGLUMLP``.

    Returns:
        The constructed ``LigerSwiGLUMLP`` instance.
    """
    # Work on a private copy so the caller's config object is never mutated.
    mlp_config = deepcopy(config)
    if intermediate_size:
        mlp_config.intermediate_size = intermediate_size
    return LigerSwiGLUMLP(mlp_config, **kwargs)
|
||||||
|
|
||||||
|
modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
|
||||||
if layer_norm:
|
if layer_norm:
|
||||||
modeling_llama4.nn.LayerNorm = LigerLayerNorm
|
modeling_llama4.nn.LayerNorm = LigerLayerNorm
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user