From f85861a0b222222b9203b0bb201975594e079292 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Wed, 9 Apr 2025 13:53:17 +0700 Subject: [PATCH] fix: liger swiglu for llama4 (#2504) * fix: liger swiglu for llama4 * feat: add liger to deepseek v3 * fix: unpack not found * fix: spelling * fix: comment out deepseek v3 * fix: retest deepseek * fix: map glu * fix: patch model forward * chore: add temp code to save * fix: remove deepseek to move into separate PR --- src/axolotl/integrations/liger/__init__.py | 6 ++++-- src/axolotl/integrations/liger/models/llama4.py | 12 +++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 8d737175e..8e305e0f3 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -185,5 +185,7 @@ class LigerPlugin(BasePlugin): rms_norm=cfg.liger_rms_norm, layer_norm=cfg.liger_layer_norm, ) - elif cfg.model_config_type in ["deepseek_v3"]: - raise ValueError(f"Unsupported model config type: {cfg.model_config_type}") + else: + logging.warning( + f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." + ) diff --git a/src/axolotl/integrations/liger/models/llama4.py b/src/axolotl/integrations/liger/models/llama4.py index da35b114c..689823bb6 100644 --- a/src/axolotl/integrations/liger/models/llama4.py +++ b/src/axolotl/integrations/liger/models/llama4.py @@ -3,6 +3,7 @@ Liger FLCE for llama4 """ import sys +from copy import deepcopy from typing import List, Optional, Tuple, Union import torch @@ -158,7 +159,16 @@ def apply_liger_kernel_to_llama4( if rms_norm: modeling_llama4.Llama4TextRMSNorm = LigerRMSNorm if glu_activation: - modeling_llama4.Llama4TextMLP = LigerSwiGLUMLP + + def _liger_swiglu_mlp_wrapper(config, intermediate_size=None, **kwargs): + "Accepts intermediate_size to pass to LigerSwiGLUMLP" + # clone config to avoid modifying the original + config = deepcopy(config) + if intermediate_size: + setattr(config, "intermediate_size", intermediate_size) + return LigerSwiGLUMLP(config, **kwargs) + + modeling_llama4.Llama4TextMLP = _liger_swiglu_mlp_wrapper if layer_norm: modeling_llama4.nn.LayerNorm = LigerLayerNorm