""" Grokfast plugin for Axolotl """ import logging from transformers.trainer_callback import TrainerCallback from ..base import BasePlugin from .args import GrokfastArgs # pylint: disable=unused-import. # noqa: F401 from .optimizer import gradfilter_ema LOG = logging.getLogger("axolotl.integrations.grokfast") class GrokfastCallbackHandler(TrainerCallback): """ Transformer trainer callbacks for Grokfast """ def __init__(self, *args_, alpha=0.98, lamb=2.0, **kwargs): super().__init__(*args_, **kwargs) self.grads = None self.alpha = alpha self.lamb = lamb def on_train_begin(self, *args_, **kwargs): # pylint: disable=unused-argument self.grads = None def on_pre_optimizer_step( self, args_, state, control, **kwargs ): # pylint: disable=unused-argument model = kwargs.pop("model") self.grads = gradfilter_ema(model, self.grads, alpha=self.alpha, lamb=self.lamb) return control class GrokfastPlugin(BasePlugin): """ Plugin for Grokfast optimizer integraton with Axolotl. """ def get_input_args(self): return "axolotl.integrations.grokfast.GrokfastArgs" def add_callbacks_post_trainer(self, cfg, trainer): LOG.info("Adding Grokfast callback to the trainer") callback = GrokfastCallbackHandler( alpha=cfg.grokfast_alpha, lamb=cfg.grokfast_lamb ) return [callback]