diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index d58349932..2047f3815 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -20,6 +20,7 @@ It is designed to be performant, correct, and light-weight. """ import logging import sys +from functools import partial from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP @@ -84,7 +85,9 @@ class LigerPlugin(BasePlugin): if cfg.liger_rope: modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb if cfg.liger_rms_norm: - modeling_gemma.GemmaRMSNorm = LigerRMSNorm + modeling_gemma.GemmaRMSNorm = partial( + LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + ) if cfg.liger_swiglu: modeling_gemma.GemmaMLP = LigerGEGLUMLP if cfg.liger_cross_entropy: @@ -156,7 +159,9 @@ class LigerPlugin(BasePlugin): if cfg.liger_rope: modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb if cfg.liger_rms_norm: - modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm + modeling_gemma2.Gemma2RMSNorm = partial( + LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma" + ) if cfg.liger_swiglu: modeling_gemma2.Gemma2MLP = LigerGEGLUMLP if cfg.liger_cross_entropy: