diff --git a/src/axolotl/integrations/liger/plugin.py b/src/axolotl/integrations/liger/plugin.py index ac796c2c9..27dbdf9d4 100644 --- a/src/axolotl/integrations/liger/plugin.py +++ b/src/axolotl/integrations/liger/plugin.py @@ -23,6 +23,29 @@ class LigerPlugin(BasePlugin): return "axolotl.integrations.liger.LigerArgs" def pre_model_load(self, cfg): + """ + Apply LIGER runtime patches and integrations according to the provided configuration. + + This hook inspects `cfg` and conditionally applies LIGER kernel patches, replacements, and model-specific integrations (rotary embeddings, normalization, GLU variants, and cross-entropy implementations) for the model type indicated by `cfg.model_config_type`. Behavior is driven entirely by the various `cfg.liger_*` flags; the method logs actions and warnings when support is experimental or unavailable. + + Parameters: + cfg: Configuration object containing LIGER-related flags and model identification. Expected attributes include: + - model_config_type (str): Target model config type to determine which patches to apply. + - base_model (str): Base model identifier used when probing model modules (used for some model types). + - trust_remote_code (bool|None): Passed when loading remote model code (used for some model types). + - torch_compile (bool): If true, disable torch.compile optimizations for certain LIGER kernels. + - liger_cross_entropy (bool) + - liger_fused_linear_cross_entropy (bool) + - liger_use_token_scaling (bool) + - liger_rope (bool) + - liger_rms_norm (bool) + - liger_layer_norm (bool) + - liger_glu_activation (str|bool): Name or flag for GLU/SwiGLU activation selection. + (Other LIGER flags referenced by the code may also be consulted.) + + Raises: + ValueError: If both `cfg.liger_cross_entropy` and `cfg.liger_fused_linear_cross_entropy` are enabled. + """ if cfg.torch_compile: # torch compile will unnecessarily attempt to optimize the triton kernel unless explicitly disabled import liger_kernel.ops.fused_linear_cross_entropy @@ -168,6 +191,22 @@ class LigerPlugin(BasePlugin): rms_norm=cfg.liger_rms_norm, layer_norm=cfg.liger_layer_norm, ) + elif cfg.model_config_type == "qwen3_vl": + """ + Apply Liger kernels for Qwen3 Vision-Language models. + + Note: The parameter 'swiglu' is used instead of 'glu_activation' to match + the Liger kernel API for vision-language models. + """ + from liger_kernel.transformers import apply_liger_kernel_to_qwen3_vl + + apply_liger_kernel_to_qwen3_vl( + rope=cfg.liger_rope, + cross_entropy=cfg.liger_cross_entropy, + fused_linear_cross_entropy=cfg.liger_fused_linear_cross_entropy, + rms_norm=cfg.liger_rms_norm, + swiglu=cfg.liger_glu_activation, # Note: qwen3_vl uses swiglu parameter name + ) elif cfg.model_config_type == "qwen3_moe": from axolotl.integrations.liger.models.qwen3_moe import ( apply_liger_kernel_to_qwen3_moe, @@ -206,4 +245,4 @@ class LigerPlugin(BasePlugin): else: LOG.warning( f"Unsupported model config type: {cfg.model_config_type}. Liger not applied." - ) + ) \ No newline at end of file