add logging

This commit is contained in:
Wing Lian
2024-10-15 12:10:35 -04:00
committed by NanoCode012
parent 415399b565
commit c2a48c3a1e

View File

@@ -30,8 +30,11 @@ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
LOG = logging.getLogger("axolotl.integrations.liger")
class LigerPlugin(BasePlugin):
"""
@@ -62,6 +65,10 @@ class LigerPlugin(BasePlugin):
kwargs["geglu"] = cfg.liger_glu_activation
elif "swiglu" in liger_fn_sig.parameters:
kwargs["swiglu"] = cfg.liger_glu_activation
with zero_only():
LOG.info(
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
)
apply_liger_fn(**kwargs)
elif cfg.model_config_type == "jamba":
from transformers.models.jamba import modeling_jamba