Fix RMSNorm monkey patch for Gemma models (#1886)

This commit is contained in:
Chiwan Park
2024-09-02 07:34:24 +09:00
committed by GitHub
parent 3c6b9eda2e
commit bdab3ec587

View File

@@ -20,6 +20,7 @@ It is designed to be performant, correct, and light-weight.
"""
import logging
import sys
+from functools import partial
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
@@ -84,7 +85,9 @@ class LigerPlugin(BasePlugin):
if cfg.liger_rope:
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
-modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+modeling_gemma.GemmaRMSNorm = partial(
+LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+)
if cfg.liger_swiglu:
modeling_gemma.GemmaMLP = LigerGEGLUMLP
if cfg.liger_cross_entropy:
@@ -156,7 +159,9 @@ class LigerPlugin(BasePlugin):
if cfg.liger_rope:
modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
if cfg.liger_rms_norm:
-modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm
+modeling_gemma2.Gemma2RMSNorm = partial(
+LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+)
if cfg.liger_swiglu:
modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
if cfg.liger_cross_entropy: