From 15408d0f09062463a50d39cb0ff356fa1590a2b3 Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Sat, 31 Aug 2024 18:59:48 -0700 Subject: [PATCH] Update supported models for Liger Kernel (#1875) * Update supported models for Liger Kernel Add Mistral LCE, Gemma LCE, Gemma 2 without LCE (softcapping is not yet implemented for Gemma in Liger Kernel LCE forward), Phi3 without LCE * move import to their appropriate conditions * Integrate Phi3 LCE support https://github.com/linkedin/Liger-Kernel/pull/103/ --------- Co-authored-by: Wing Lian --- src/axolotl/integrations/liger/__init__.py | 51 +++++++++++++++++++--- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 2a3e95163..d58349932 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -23,7 +23,6 @@ import sys from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss from liger_kernel.transformers.geglu import LigerGEGLUMLP -from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb from liger_kernel.transformers.swiglu import LigerSwiGLUMLP @@ -43,6 +42,9 @@ class LigerPlugin(BasePlugin): def pre_model_load(self, cfg): if cfg.model_config_type == "llama": + from liger_kernel.transformers.model.llama import ( + lce_forward as llama_lce_forward, + ) from transformers.models.llama import modeling_llama if cfg.liger_rope: @@ -57,6 +59,9 @@ class LigerPlugin(BasePlugin): modeling_llama.LlamaForCausalLM.forward = llama_lce_forward elif cfg.model_config_type == "mistral": + from liger_kernel.transformers.model.mistral import ( + lce_forward as mistral_lce_forward, + ) from transformers.models.mistral import modeling_mistral if cfg.liger_rope: @@ -68,11 +73,12 @@ class LigerPlugin(BasePlugin): if cfg.liger_cross_entropy: modeling_mistral.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: - logging.warning( - "Fused linear cross entropy is not supported for Mistral." - ) + modeling_mistral.MistralForCausalLM.forward = mistral_lce_forward elif cfg.model_config_type == "gemma": + from liger_kernel.transformers.model.gemma import ( + lce_forward as gemma_lce_forward, + ) from transformers.models.gemma import modeling_gemma if cfg.liger_rope: @@ -84,9 +90,7 @@ class LigerPlugin(BasePlugin): if cfg.liger_cross_entropy: modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: - logging.warning( - "Fused linear cross entropy is not supported for Gemma." - ) + modeling_gemma.GemmaForCausalLM.forward = gemma_lce_forward elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba @@ -145,3 +149,36 @@ class LigerPlugin(BasePlugin): modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss if cfg.liger_fused_linear_cross_entropy: modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward + + elif cfg.model_config_type == "gemma2": + from transformers.models.gemma2 import modeling_gemma2 + + if cfg.liger_rope: + modeling_gemma2.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_gemma2.Gemma2RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_gemma2.Gemma2MLP = LigerGEGLUMLP + if cfg.liger_cross_entropy: + modeling_gemma2.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + logging.warning( + "Fused linear cross entropy is not supported for Gemma 2." + ) + + elif cfg.model_config_type == "phi3": + from liger_kernel.transformers.model.phi3 import ( + lce_forward as phi3_lce_forward, + ) + from transformers.models.phi3 import modeling_phi3 + + if cfg.liger_rope: + modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb + if cfg.liger_rms_norm: + modeling_phi3.Phi3RMSNorm = LigerRMSNorm + if cfg.liger_swiglu: + modeling_phi3.Phi3MLP = LigerSwiGLUMLP + if cfg.liger_cross_entropy: + modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss + if cfg.liger_fused_linear_cross_entropy: + modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward