fix: liger swiglu for llama4 (#2504)
* fix: liger swiglu for llama4
* feat: add liger to deepseek v3
* fix: unpack not found
* fix: spelling
* fix: comment out deepseek v3
* fix: retest deepseek
* fix: map glu
* fix: patch model forward
* chore: add temp code to save
* fix: remove deepseek to move into separate PR
This commit is contained in:
@@ -185,5 +185,7 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type in ["deepseek_v3"]:
|
||||
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
|
||||
else:
|
||||
logging.warning(
|
||||
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ Liger FLCE for llama4
|
||||
"""
|
||||
|
||||
import sys
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -158,7 +159,16 @@ def apply_liger_kernel_to_llama4(
|
||||
if rms_norm:
|
||||
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
|
||||
if glu_activation:
|
||||
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
|
||||
|
||||
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
    """Build a ``LigerSwiGLUMLP`` from a private copy of ``config``.

    Accepts the extra ``intermediate_size`` argument and forwards it to
    ``LigerSwiGLUMLP`` by writing it onto the copied config.

    Args:
        config: Model config handed to ``LigerSwiGLUMLP``. It is deep-copied
            first, so the override below is never written to the caller's
            object.
        intermediate_size: Optional override stored on the copied config's
            ``intermediate_size`` attribute. ``None`` (the default) leaves
            the copied config's existing value untouched.
        **kwargs: Forwarded unchanged to ``LigerSwiGLUMLP``.

    Returns:
        The ``LigerSwiGLUMLP`` instance built from the copied config.
    """
    # Clone the config so the override cannot leak into the shared original.
    mlp_config = deepcopy(config)
    # Test against None (not truthiness) so that only an omitted argument
    # falls back to the config's own value.
    # NOTE(review): this is meant to mirror the ``is None`` check in the
    # upstream Llama4TextMLP constructor -- confirm against the installed
    # transformers version.
    if intermediate_size is not None:
        mlp_config.intermediate_size = intermediate_size
    return LigerSwiGLUMLP(mlp_config, **kwargs)
|
||||
|
||||
modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
|
||||
if layer_norm:
|
||||
modeling_llama4.nn.LayerNorm = LigerLayerNorm
|
||||
|
||||
|
||||
Reference in New Issue
Block a user