Compare commits
1 Commits
fix/cp-was
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
93600fa80d |
@@ -23,6 +23,29 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""
|
||||
Apply Liger runtime patches and integrations according to the provided configuration.
|
||||
|
||||
This hook inspects `cfg` and conditionally applies Liger kernel patches, replacements, and model-specific integrations (rotary embeddings, normalization, GLU variants, and cross-entropy implementations) for the model type indicated by `cfg.model_config_type`. Which patches are applied is controlled by `cfg.model_config_type`, `cfg.torch_compile`, and the `cfg.liger_*` flags. The method logs the actions it takes and emits a warning when support is experimental or unavailable for the requested model type.
|
||||
|
||||
Parameters:
|
||||
cfg: Configuration object containing Liger-related flags and model identification. Expected attributes include:
|
||||
- model_config_type (str): Target model config type to determine which patches to apply.
|
||||
- base_model (str): Base model identifier used when probing model modules (used for some model types).
|
||||
- trust_remote_code (bool|None): Passed when loading remote model code (used for some model types).
|
||||
- torch_compile (bool): If true, explicitly exclude certain Liger Triton kernels from torch.compile optimization, which would otherwise try to optimize them unnecessarily.
|
||||
- liger_cross_entropy (bool)
|
||||
- liger_fused_linear_cross_entropy (bool)
|
||||
- liger_use_token_scaling (bool)
|
||||
- liger_rope (bool)
|
||||
- liger_rms_norm (bool)
|
||||
- liger_layer_norm (bool)
|
||||
- liger_glu_activation (str|bool): Name or flag for GLU/SwiGLU activation selection.
|
||||
(Other `liger_*` flags referenced by the code may also be read.)
|
||||
|
||||
Raises:
|
||||
ValueError: If both `cfg.liger_cross_entropy` and `cfg.liger_fused_linear_cross_entropy` are enabled.
|
||||
"""
|
||||
if cfg.torch_compile:
|
||||
# torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled
|
||||
import liger_kernel.ops.fused_linear_cross_entropy
|
||||
@@ -168,6 +191,22 @@ class LigerPlugin(BasePlugin):
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
layer_norm=cfg.liger_layer_norm,
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_vl":
|
||||
"""
|
||||
Apply Liger kernels for Qwen3 Vision-Language models.
|
||||
|
||||
Note: The parameter 'swiglu' is used instead of 'glu_activation' to match
|
||||
the Liger kernel API for vision-language models.
|
||||
"""
|
||||
from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl
|
||||
|
||||
apply_liger_kernel_to_qwen3_vl(
|
||||
rope=cfg.liger_rope,
|
||||
cross_entropy=cfg.liger_cross_entropy,
|
||||
fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy,
|
||||
rms_norm=cfg.liger_rms_norm,
|
||||
swiglu=cfg.liger_glu_activation, # Note: qwen3_vl uses swiglu parameter name
|
||||
)
|
||||
elif cfg.model_config_type == "qwen3_moe":
|
||||
from axolotl.integrations.liger.models.qwen3_moe import (
|
||||
apply_liger_kernel_to_qwen3_moe,
|
||||
@@ -206,4 +245,4 @@ class LigerPlugin(BasePlugin):
|
||||
else:
|
||||
LOG.warning(
|
||||
f"Unsupported model config type: {cfg.model_config_type}. Liger not applied."
|
||||
)
|
||||
)
|
||||
Reference in New Issue
Block a user