diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 22d988633..c7ac42372 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -175,6 +175,16 @@ class LigerPlugin(BasePlugin): rms_norm=cfg.liger_rms_norm, layer_norm=cfg.liger_layer_norm, ) + elif cfg.model_config_type == "granitemoe": + from liger_kernel.transformers import apply_liger_kernel_to_granite + + apply_liger_kernel_to_granite( + rope=cfg.liger_rope, + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + rms_norm=cfg.liger_rms_norm, + swiglu=cfg.liger_glu_activation, + ) else: logging.warning( f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."