Fix RMSNorm monkey patch for Gemma models (#1886)
This commit is contained in:
@@ -20,6 +20,7 @@ It is designed to be performant, correct, and light-weight.
 """
 import logging
 import sys
+from functools import partial
 
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
@@ -84,7 +85,9 @@ class LigerPlugin(BasePlugin):
 if cfg.liger_rope:
     modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
 if cfg.liger_rms_norm:
-    modeling_gemma.GemmaRMSNorm = LigerRMSNorm
+    modeling_gemma.GemmaRMSNorm = partial(
+        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+    )
 if cfg.liger_swiglu:
     modeling_gemma.GemmaMLP = LigerGEGLUMLP
 if cfg.liger_cross_entropy:
@@ -156,7 +159,9 @@ class LigerPlugin(BasePlugin):
 if cfg.liger_rope:
     modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb
 if cfg.liger_rms_norm:
-    modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm
+    modeling_gemma2.Gemma2RMSNorm = partial(
+        LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
+    )
 if cfg.liger_swiglu:
     modeling_gemma2.Gemma2MLP = LigerGEGLUMLP
 if cfg.liger_cross_entropy:
Reference in New Issue
Block a user