diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index 8d737175e..8e305e0f3 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -185,5 +185,7 @@ class LigerPlugin(BasePlugin):
                 rms_norm=cfg.liger_rms_norm,
                 layer_norm=cfg.liger_layer_norm,
             )
-        elif cfg.model_config_type in ["deepseek_v3"]:
-            raise ValueError(f"Unsupported model config type: {cfg.model_config_type}")
+        else:
+            logging.warning(
+                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
+            )
diff --git a/src/axolotl/integrations/liger/models/llama4.py b/src/axolotl/integrations/liger/models/llama4.py
index da35b114c..689823bb6 100644
--- a/src/axolotl/integrations/liger/models/llama4.py
+++ b/src/axolotl/integrations/liger/models/llama4.py
@@ -3,6 +3,7 @@ Liger FLCE for llama4
 """
 
 import sys
+from copy import deepcopy
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -158,7 +159,22 @@ def apply_liger_kernel_to_llama4(
     if rms_norm:
         modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm
     if glu_activation:
-        modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP
+
+        def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs):
+            """Build a LigerSwiGLUMLP, applying an optional intermediate_size.
+
+            When ``intermediate_size`` is supplied it is written onto a deep
+            copy of ``config``, and that copy is what LigerSwiGLUMLP receives;
+            the caller's config object is never mutated.
+            """
+            # clone config to avoid modifying the original
+            config = deepcopy(config)
+            # None is the "not supplied" sentinel from the signature; compare
+            # against it explicitly instead of relying on truthiness.
+            if intermediate_size is not None:
+                config.intermediate_size = intermediate_size
+            return LigerSwiGLUMLP(config, **kwargs)
+
+        modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper
     if layer_norm:
         modeling_llama4.nn.LayerNorm = LigerLayerNorm
 