fix: upgrade liger to 0.5.8 and use native Gemma3 patches (#2527)

* fix: upgrade liger to 0.5.8 and use native Gemma3 patches

* fix: make lint happy

* doc: update Liger Kernel FLCE support for Gemma 3
This commit is contained in:
Chiwan Park
2025-04-19 01:57:40 +09:00
committed by GitHub
parent 60a8f0958d
commit 4ce469d32e
3 changed files with 2 additions and 36 deletions

View File

@@ -6,7 +6,7 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.6
liger-kernel==0.5.8
# END section
packaging==23.2

View File

@@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true
- deepseek_v2
- gemma
- gemma2
- gemma3 (partial support, no support for FLCE yet)
- gemma3
- granite
- jamba
- llama

View File

@@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight.
import inspect
import logging
import sys
from functools import partial
from axolotl.integrations.base import BasePlugin
@@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin):
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
@@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin):
modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
if cfg.liger_fused_linear_cross_entropy:
modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
elif cfg.model_config_type in ["gemma3", "gemma3_text"]:
from transformers.models.gemma3 import modeling_gemma3
if cfg.liger_rope:
modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
def _liger_rms_norm_wrapper(dim, **kwargs):
"Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm"
return LigerRMSNorm(hidden_size=dim, **kwargs)
modeling_gemma3.Gemma3RMSNorm = partial(
_liger_rms_norm_wrapper,
offset=1.0,
casting_mode="gemma",
init_fn="zeros",
in_place=False,
)
if cfg.liger_glu_activation:
modeling_gemma3.Gemma3MLP = LigerGEGLUMLP
if cfg.liger_layer_norm:
modeling_gemma3.nn.LayerNorm = LigerLayerNorm
if cfg.liger_cross_entropy:
from transformers.loss.loss_utils import nn
nn.functional.cross_entropy = liger_cross_entropy
if cfg.liger_fused_linear_cross_entropy:
raise NotImplementedError(
"Fused linear cross entropy is not yet supported for Gemma3."
)
elif cfg.model_config_type == "llama4":
from axolotl.integrations.liger.models.llama4 import (
apply_liger_kernel_to_llama4,