* get_logger use_environ fix * rethinking * replacing old logger imports * simplify * fix boolean cond
188 lines
8.3 KiB
Python
188 lines
8.3 KiB
Python
# Copyright 2024 Axolotl AI. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Module for the Plugin for LIGER integration with Axolotl.
|
|
|
|
Liger Kernel is a collection of Triton-native kernels for LLM training.
|
|
It is designed to be performant, correct, and light-weight.
|
|
"""
|
|
import inspect
|
|
import sys
|
|
|
|
from axolotl.integrations.base import BasePlugin
|
|
from axolotl.utils.logging import get_logger
|
|
|
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
|
from .utils import patch_with_compile_disable
|
|
|
|
# Module-level logger for this plugin, created through axolotl's `get_logger` helper.
LOG = get_logger(__name__)
|
|
|
|
|
|
class LigerPlugin(BasePlugin):
    """
    Plugin for LIGER integration with Axolotl.

    Before the model is loaded, replaces selected modules and functions of the
    configured model type with their Liger Kernel counterparts, as selected by
    the ``liger_*`` flags on the config.
    """

    def get_input_args(self):
        """Return the dotted import path of this plugin's args class."""
        return "axolotl.integrations.liger.LigerArgs"

    def pre_model_load(self, cfg):
        """
        Apply the Liger patches selected by ``cfg`` for ``cfg.model_config_type``.

        Model types present in liger_kernel's ``MODEL_TYPE_TO_APPLY_LIGER_FN``
        are handled by the registered function, which only receives the keyword
        arguments its signature accepts. ``jamba``, ``deepseek_v2``, ``llama4``,
        ``qwen3``, ``qwen3_moe`` and ``granitemoe`` have dedicated branches. Any
        other model type only produces a warning and is left unpatched.

        Args:
            cfg: Config object. Reads ``torch_compile``, ``model_config_type``
                and the ``liger_rope``, ``liger_rms_norm``, ``liger_layer_norm``,
                ``liger_glu_activation``, ``liger_cross_entropy`` and
                ``liger_fused_linear_cross_entropy`` flags. The ``deepseek_v2``
                branch additionally reads ``base_model`` and
                ``trust_remote_code``.

        Raises:
            ValueError: If both ``liger_cross_entropy`` and
                ``liger_fused_linear_cross_entropy`` are set.
        """
        # Reject contradictory settings first, so that an invalid config does
        # not leave any of the patches below half-applied.
        if cfg.liger_cross_entropy and cfg.liger_fused_linear_cross_entropy:
            raise ValueError(
                "Cannot have both `liger_cross_entropy` and `liger_fused_linear_cross_entropy` set."
            )

        if cfg.torch_compile:
            # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
            import liger_kernel.ops.fused_linear_cross_entropy

            patch_with_compile_disable(
                liger_kernel.ops.fused_linear_cross_entropy,
                "fused_linear_cross_entropy_forward",
            )
            patch_with_compile_disable(
                liger_kernel.ops.fused_linear_cross_entropy,
                "fused_linear_cross_entropy_backward",
            )

        # Imported here rather than at module level, so liger_kernel is only
        # required once this hook actually runs.
        from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
        from liger_kernel.transformers.functional import liger_cross_entropy
        from liger_kernel.transformers.layer_norm import LigerLayerNorm
        from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
        from liger_kernel.transformers.rms_norm import LigerRMSNorm
        from liger_kernel.transformers.rope import liger_rotary_pos_emb
        from liger_kernel.transformers.swiglu import LigerSwiGLUMLP

        if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
            apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
            # Only forward the options this particular apply function declares;
            # the accepted keyword set differs between model types.
            liger_fn_sig = inspect.signature(apply_liger_fn)
            kwargs = {}
            if "rope" in liger_fn_sig.parameters:
                kwargs["rope"] = cfg.liger_rope
            if "cross_entropy" in liger_fn_sig.parameters:
                kwargs["cross_entropy"] = cfg.liger_cross_entropy
            if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
                kwargs["fused_linear_cross_entropy"] = (
                    cfg.liger_fused_linear_cross_entropy
                )
            if "rms_norm" in liger_fn_sig.parameters:
                kwargs["rms_norm"] = cfg.liger_rms_norm
            if "layer_norm" in liger_fn_sig.parameters:
                kwargs["layer_norm"] = cfg.liger_layer_norm
            # The single `liger_glu_activation` flag feeds whichever GLU
            # parameter the function exposes; `geglu` wins if both exist.
            if "geglu" in liger_fn_sig.parameters:
                kwargs["geglu"] = cfg.liger_glu_activation
            elif "swiglu" in liger_fn_sig.parameters:
                kwargs["swiglu"] = cfg.liger_glu_activation
            LOG.info(f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}")
            apply_liger_fn(**kwargs)
        elif cfg.model_config_type == "jamba":
            from transformers.models.jamba import modeling_jamba

            from .models.jamba import lce_forward as jamba_lce_forward

            if cfg.liger_rope:
                modeling_jamba.apply_rotary_pos_emb = liger_rotary_pos_emb
            if cfg.liger_rms_norm:
                modeling_jamba.JambaRMSNorm = LigerRMSNorm
            if cfg.liger_glu_activation:
                modeling_jamba.JambaMLP = LigerSwiGLUMLP
            if cfg.liger_layer_norm:
                modeling_jamba.nn.LayerNorm = LigerLayerNorm
            if cfg.liger_cross_entropy:
                # Patch `cross_entropy` on the `nn` namespace that
                # transformers' loss utilities import.
                from transformers.loss.loss_utils import nn

                nn.functional.cross_entropy = liger_cross_entropy
            if cfg.liger_fused_linear_cross_entropy:
                modeling_jamba.JambaForCausalLM.forward = jamba_lce_forward
        elif cfg.model_config_type == "deepseek_v2":
            from accelerate import init_empty_weights
            from transformers import AutoModelForCausalLM

            # Instantiate the model under `init_empty_weights` only to find out
            # which module defines its class; the patches below are applied to
            # that module.
            with init_empty_weights():
                model = AutoModelForCausalLM.from_pretrained(
                    cfg.base_model, trust_remote_code=cfg.trust_remote_code or False
                )
                modeling_mod = sys.modules[model.__class__.__module__]

            from .models.deepseekv2 import lce_forward as deepseekv2_lce_forward

            if cfg.liger_rope:
                # The DeepseekV2 version of RoPE is different than upstream LLaMA.
                # See https://github.com/linkedin/Liger-Kernel/issues/129#issuecomment-2313763528
                LOG.warning("Fused liger_rope is not supported for DeepseekV2.")
            if cfg.liger_glu_activation:
                LOG.warning("liger_glu_activation is not supported for DeepseekV2.")
            if cfg.liger_rms_norm:
                modeling_mod.DeepseekV2RMSNorm = LigerRMSNorm
            if cfg.liger_glu_activation:
                # NOTE(review): this patch is applied even though the warning
                # above reports the option as unsupported - confirm which of
                # the two is intended. Behaviour is kept as-is.
                modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
            if cfg.liger_layer_norm:
                # Replace the LayerNorm class reachable through the modeling
                # module's `nn` namespace, the same way the jamba branch does.
                # This used to assign LigerLayerNorm.forward to
                # DeepseekV2MLP.forward, which replaced the MLP's forward with
                # a LayerNorm forward and discarded the SwiGLU patch above.
                # NOTE(review): assumes the modeling module exposes `nn` at
                # module level - confirm against the remote code.
                modeling_mod.nn.LayerNorm = LigerLayerNorm
            if cfg.liger_cross_entropy:
                # We do not patch `nn.functional.cross_entropy` for DeepseekV2 as it still uses
                # nn.CrossEntropyLoss in the forward method.
                modeling_mod.CrossEntropyLoss = LigerCrossEntropyLoss
            if cfg.liger_fused_linear_cross_entropy:
                modeling_mod.DeepseekV2ForCausalLM.forward = deepseekv2_lce_forward
        elif cfg.model_config_type == "llama4":
            from axolotl.integrations.liger.models.llama4 import (
                apply_liger_kernel_to_llama4,
            )

            apply_liger_kernel_to_llama4(
                cross_entropy=cfg.liger_cross_entropy,
                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
                glu_activation=cfg.liger_glu_activation,
                rms_norm=cfg.liger_rms_norm,
                layer_norm=cfg.liger_layer_norm,
            )
        elif cfg.model_config_type == "qwen3":
            from axolotl.integrations.liger.models.qwen3 import (
                apply_liger_kernel_to_qwen3,
            )

            apply_liger_kernel_to_qwen3(
                cross_entropy=cfg.liger_cross_entropy,
                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
                glu_activation=cfg.liger_glu_activation,
                rms_norm=cfg.liger_rms_norm,
                layer_norm=cfg.liger_layer_norm,
            )
        elif cfg.model_config_type == "qwen3_moe":
            from axolotl.integrations.liger.models.qwen3_moe import (
                apply_liger_kernel_to_qwen3_moe,
            )

            apply_liger_kernel_to_qwen3_moe(
                cross_entropy=cfg.liger_cross_entropy,
                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
                glu_activation=cfg.liger_glu_activation,
                rms_norm=cfg.liger_rms_norm,
                layer_norm=cfg.liger_layer_norm,
            )
        elif cfg.model_config_type == "granitemoe":
            from liger_kernel.transformers import apply_liger_kernel_to_granite

            apply_liger_kernel_to_granite(
                rope=cfg.liger_rope,
                cross_entropy=cfg.liger_cross_entropy,
                fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
                rms_norm=cfg.liger_rms_norm,
                swiglu=cfg.liger_glu_activation,
            )
        else:
            LOG.warning(
                f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
            )
|