add logging
This commit is contained in:
@@ -30,8 +30,11 @@ from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
|||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
|
from ...utils.distributed import zero_only
|
||||||
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LigerArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.liger")
|
||||||
|
|
||||||
|
|
||||||
class LigerPlugin(BasePlugin):
|
class LigerPlugin(BasePlugin):
|
||||||
"""
|
"""
|
||||||
@@ -62,6 +65,10 @@ class LigerPlugin(BasePlugin):
|
|||||||
kwargs["geglu"] = cfg.liger_glu_activation
|
kwargs["geglu"] = cfg.liger_glu_activation
|
||||||
elif "swiglu" in liger_fn_sig.parameters:
|
elif "swiglu" in liger_fn_sig.parameters:
|
||||||
kwargs["swiglu"] = cfg.liger_glu_activation
|
kwargs["swiglu"] = cfg.liger_glu_activation
|
||||||
|
with zero_only():
|
||||||
|
LOG.info(
|
||||||
|
f"Applying LIGER to {cfg.model_config_type} with kwargs: {kwargs}"
|
||||||
|
)
|
||||||
apply_liger_fn(**kwargs)
|
apply_liger_fn(**kwargs)
|
||||||
elif cfg.model_config_type == "jamba":
|
elif cfg.model_config_type == "jamba":
|
||||||
from transformers.models.jamba import modeling_jamba
|
from transformers.models.jamba import modeling_jamba
|
||||||
|
|||||||
Reference in New Issue
Block a user