diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 9183ae309..a64d748c6 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -30,8 +30,11 @@ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from axolotl.integrations.base import BasePlugin +from ...utils.distributed import zero_only from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 +LOG = logging.getLogger("axolotl.integrations.liger") + class LigerPlugin(BasePlugin): """ @@ -62,6 +65,10 @@ class LigerPlugin(BasePlugin): kwargs["geglu"] = cfg.liger_glu_activation elif "swiglu" in liger_fn_sig.parameters: kwargs["swiglu"] = cfg.liger_glu_activation + with zero_only(): + LOG.info( + f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}" + ) apply_liger_fn(**kwargs) elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba