From c2a48c3a1eb91fcd57c3fa3e1c91e7ab81a462b2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 15 Oct 2024 12:10:35 -0400 Subject: [PATCH] add logging --- src/axolotl/integrations/liger/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/axolotl/integrations/liger/__init__.py b/src/axolotl/integrations/liger/__init__.py index 9183ae309..a64d748c6 100644 --- a/src/axolotl/integrations/liger/__init__.py +++ b/src/axolotl/integrations/liger/__init__.py @@ -30,8 +30,11 @@ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from axolotl.integrations.base import BasePlugin +from ...utils.distributed import zero_only from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401 +LOG = logging.getLogger("axolotl.integrations.liger") + class LigerPlugin(BasePlugin): """ @@ -62,6 +65,10 @@ class LigerPlugin(BasePlugin): kwargs["geglu"] = cfg.liger_glu_activation elif "swiglu" in liger_fn_sig.parameters: kwargs["swiglu"] = cfg.liger_glu_activation + with zero_only(): + LOG.info( + f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}" + ) apply_liger_fn(**kwargs) elif cfg.model_config_type == "jamba": from transformers.models.jamba import modeling_jamba