From bdab3ec587f9e85f5684cf57302a1bf93e78de1e Mon Sep 17 00:00:00 2001
From: Chiwan Park
Date: Mon, 2 Sep 2024 07:34:24 +0900
Subject: [PATCH] Fix RMSNorm monkey patch for Gemma models (#1886)

---
 src/axolotl/integrations/liger/__init__.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py
index d58349932..2047f3815 100644
--- a/src/axolotl/integrations/liger/__init__.py
+++ b/src/axolotl/integrations/liger/__init__.py
@@ -20,6 +20,7 @@ It is designed to be performant, correct, and light-weight.
 """
 import logging
 import sys
+from functools import partial
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
@@ -84,7 +85,9 @@ class LigerPlugin(BasePlugin):
             if cfg.liger_rope:
                 modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
             if cfg.liger_rms_norm:
-                modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+                modeling_gemma.GemmaRMSNorm = partial(
+                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+                )
             if cfg.liger_swiglu:
                 modeling_gemma.GemmaMLP = LigerGEGLUMLP
             if cfg.liger_cross_entropy:
@@ -156,7 +159,9 @@ class LigerPlugin(BasePlugin):
             if cfg.liger_rope:
                 modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
             if cfg.liger_rms_norm:
-                modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm
+                modeling_gemma2.Gemma2RMSNorm = partial(
+                    LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+                )
             if cfg.liger_swiglu:
                 modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
             if cfg.liger_cross_entropy: