Merge branch 'main' into flex_patching_update

This commit is contained in:
Salman Mohammadi
2025-04-09 16:51:05 +01:00
6 changed files with 149 additions and 27 deletions

View File

@@ -185,5 +185,7 @@ class LigerPlugin(BasePlugin):
rms_norm=cfg.liger_rms_norm,
layer_norm=cfg.liger_layer_norm,
)
elif cfg.model_config_type in ["deepseek_v3"]:
raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
else:
logging.warning(
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
)

View File

@@ -3,6 +3,7 @@ Liger FLCE for llama4
"""
import sys
from copy import deepcopy
from typing import List, Optional, Tuple, Union
import torch
@@ -158,7 +159,16 @@ def apply_liger_kernel_to_llama4(
if rms_norm:
modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
if glu_activation:
modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
"Accepts intermediate_size to pass to LigerSwiGLUMLP"
# clone config to avoid modifying the original
config = deepcopy(config)
if intermediate_size:
setattr(config, "intermediate_size", intermediate_size)
return LigerSwiGLUMLP(config, **kwargs)
modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
if layer_norm:
modeling_llama4.nn.LayerNorm = LigerLayerNorm