diff --git a/requirements.txt b/requirements.txt index 76e8be8c1..2827c2ca1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ triton>=3.0.0 mamba-ssm==1.2.0.post1 xformers>=0.0.23.post1 autoawq==0.2.7.post3 -liger-kernel==0.5.6 +liger-kernel==0.5.8 # END section packaging==23.2 diff --git a/src/axolotl/integrations/liger/README.md b/src/axolotl/integrations/liger/README.md index 03422f889..c5cce8282 100644 --- a/src/axolotl/integrations/liger/README.md +++ b/src/axolotl/integrations/liger/README.md @@ -25,7 +25,7 @@ liger_fused_linear_cross_entropy: true - deepseek_v2 - gemma - gemma2 -- gemma3 (partial support, no support for FLCE yet) +- gemma3 - granite - jamba - llama diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 8e305e0f3..4e8d00552 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -21,7 +21,6 @@ It is designed to be performant, correct, and light-weight. import inspect import logging import sys -from functools import partial from axolotl.integrations.base import BasePlugin @@ -55,7 +54,6 @@ class LigerPlugin(BasePlugin): ) from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.functional import liger_cross_entropy - from liger_kernel.transformers.geglu import LigerGEGLUMLP from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN from liger_kernel.transformers.rms_norm import LigerRMSNorm @@ -141,38 +139,6 @@ class LigerPlugin(BasePlugin): modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward - elif cfg.model_config_type in ["gemma3", "gemma3_text"]: - from transformers.models.gemma3 import modeling_gemma3 - - if cfg.liger_rope: - modeling_gemma3.apply_rotary_pos_emb = liger_rotary_pos_emb - if cfg.liger_rms_norm: - - def _liger_rms_norm_wrapper(dim, **kwargs): - "Convert 'dim' keyword to 'hidden_size' to pass to LigerRMSNorm" - return LigerRMSNorm(hidden_size=dim, **kwargs) - - modeling_gemma3.Gemma3RMSNorm = partial( - _liger_rms_norm_wrapper, - offset=1.0, - casting_mode="gemma", - init_fn="zeros", - in_place=False, - ) - if cfg.liger_glu_activation: - modeling_gemma3.Gemma3MLP = LigerGEGLUMLP - if cfg.liger_layer_norm: - modeling_gemma3.nn.LayerNorm = LigerLayerNorm - - if cfg.liger_cross_entropy: - from transformers.loss.loss_utils import nn - - nn.functional.cross_entropy = liger_cross_entropy - - if cfg.liger_fused_linear_cross_entropy: - raise NotImplementedError( - "Fused linear cross entropy is not yet supported for Gemma3." - ) elif cfg.model_config_type == "llama4": from axolotl.integrations.liger.models.llama4 import ( apply_liger_kernel_to_llama4,